From 0ec2dec2b2f22ff2b30226fa424b35e1e8322cb7 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Thu, 13 Feb 2025 07:24:55 +0000 Subject: [PATCH 01/43] abstract pass initial commit --- paddle/CMakeLists.txt | 1 + paddle/ap/CMakeLists.txt | 76 + paddle/ap/include/adt/adt.h | 566 +++ paddle/ap/include/adt/bfs_walker.h | 73 + paddle/ap/include/adt/topo_walker.h | 90 + paddle/ap/include/axpr/abstract_list.h | 119 + paddle/ap/include/axpr/adt.h | 29 + paddle/ap/include/axpr/anf_expr.h | 94 + paddle/ap/include/axpr/anf_expr_builder.h | 61 + paddle/ap/include/axpr/anf_expr_helper.h | 214 ++ paddle/ap/include/axpr/anf_expr_util.h | 760 ++++ paddle/ap/include/axpr/atomic.h | 74 + paddle/ap/include/axpr/atomic_builder.h | 50 + paddle/ap/include/axpr/attr_map.h | 95 + .../ap/include/axpr/attr_map_method_class.h | 61 + paddle/ap/include/axpr/binary_func.h | 64 + paddle/ap/include/axpr/bool.h | 30 + paddle/ap/include/axpr/bool_helper.h | 59 + paddle/ap/include/axpr/bool_int_double.h | 41 + .../axpr/bool_int_double_arithmetic_util.h | 52 + .../ap/include/axpr/bool_int_double_helper.h | 37 + paddle/ap/include/axpr/bool_method_class.h | 135 + .../ap/include/axpr/builtin_class_instance.h | 111 + .../builtin_class_instance_method_class.h | 241 ++ paddle/ap/include/axpr/builtin_classes.h | 26 + paddle/ap/include/axpr/builtin_environment.h | 50 + paddle/ap/include/axpr/builtin_frame_util.h | 60 + .../ap/include/axpr/builtin_func_name_mgr.h | 75 + paddle/ap/include/axpr/builtin_func_type.h | 74 + .../axpr/builtin_func_type_method_class.h | 47 + paddle/ap/include/axpr/builtin_functions.h | 90 + .../axpr/builtin_high_order_func_type.h | 17 + ...uiltin_high_order_func_type_method_class.h | 42 + .../axpr/builtin_serializable_attr_map.h | 31 + ...iltin_serializable_attr_map_method_class.h | 91 + ...tin_serializable_attr_map_to_axpr_helper.h | 140 + paddle/ap/include/axpr/builtin_symbol.h | 279 ++ .../axpr/builtin_symbol_method_class.h | 45 + paddle/ap/include/axpr/call_environment.h | 67 + 
paddle/ap/include/axpr/callable_helper.h | 53 + paddle/ap/include/axpr/class_attrs.h | 36 + paddle/ap/include/axpr/class_attrs_helper.h | 56 + paddle/ap/include/axpr/class_instance.h | 58 + .../axpr/class_instance_method_class.h | 202 ++ paddle/ap/include/axpr/class_ops.h | 36 + paddle/ap/include/axpr/closure.h | 49 + paddle/ap/include/axpr/closure_method_class.h | 55 + .../include/axpr/const_global_environment.h | 67 + paddle/ap/include/axpr/constants.h | 17 + paddle/ap/include/axpr/continuation.h | 46 + .../include/axpr/continuation_method_class.h | 36 + paddle/ap/include/axpr/core_expr.h | 177 + paddle/ap/include/axpr/core_expr_builder.h | 38 + paddle/ap/include/axpr/cps_interpreter.h | 649 ++++ paddle/ap/include/axpr/data_type.h | 141 + .../ap/include/axpr/data_type_method_class.h | 138 + paddle/ap/include/axpr/data_type_util.h | 56 + paddle/ap/include/axpr/data_value.h | 109 + .../ap/include/axpr/data_value_method_class.h | 379 ++ paddle/ap/include/axpr/data_value_util.h | 115 + paddle/ap/include/axpr/dim_expr.h | 28 + .../ap/include/axpr/dim_expr_method_class.h | 102 + paddle/ap/include/axpr/environment.h | 45 + paddle/ap/include/axpr/error.h | 36 + .../ap/include/axpr/exception_method_class.h | 61 + paddle/ap/include/axpr/float.h | 31 + paddle/ap/include/axpr/float_method_class.h | 133 + paddle/ap/include/axpr/frame.h | 27 + paddle/ap/include/axpr/function.h | 51 + .../ap/include/axpr/function_method_class.h | 61 + paddle/ap/include/axpr/global_environment.h | 72 + paddle/ap/include/axpr/hash.h | 48 + paddle/ap/include/axpr/instance_attrs.h | 30 + paddle/ap/include/axpr/int.h | 31 + paddle/ap/include/axpr/int_data_type.h | 25 + paddle/ap/include/axpr/int_method_class.h | 133 + paddle/ap/include/axpr/interpreter.h | 43 + paddle/ap/include/axpr/interpreter_base.h | 53 + paddle/ap/include/axpr/lambda_expr_builder.h | 346 ++ paddle/ap/include/axpr/list.h | 30 + paddle/ap/include/axpr/list_method_class.h | 147 + paddle/ap/include/axpr/method.h | 43 + 
paddle/ap/include/axpr/method_class.h | 483 +++ paddle/ap/include/axpr/method_method_class.h | 57 + paddle/ap/include/axpr/module_mgr.h | 199 + paddle/ap/include/axpr/module_mgr_helper.h | 56 + .../include/axpr/mutable_global_environment.h | 100 + paddle/ap/include/axpr/mutable_list.h | 38 + .../include/axpr/mutable_list_method_class.h | 180 + paddle/ap/include/axpr/mutable_ordered_dict.h | 42 + .../axpr/mutable_ordered_dict_method_class.h | 184 + paddle/ap/include/axpr/naive_class_ops.h | 61 + paddle/ap/include/axpr/nothing.h | 30 + paddle/ap/include/axpr/nothing_method_class.h | 54 + paddle/ap/include/axpr/ordered_dict.h | 107 + .../include/axpr/ordered_dict_method_class.h | 152 + paddle/ap/include/axpr/packed_args.h | 57 + .../include/axpr/packed_args_method_class.h | 54 + paddle/ap/include/axpr/pointer_type.h | 81 + .../include/axpr/pointer_type_method_class.h | 136 + paddle/ap/include/axpr/pointer_type_util.h | 102 + paddle/ap/include/axpr/pointer_value.h | 49 + .../include/axpr/pointer_value_method_class.h | 104 + paddle/ap/include/axpr/s_expr.h | 61 + paddle/ap/include/axpr/serializable_list.h | 31 + .../axpr/serializable_list_method_class.h | 73 + paddle/ap/include/axpr/serializable_value.h | 147 + .../include/axpr/serializable_value_helper.h | 269 ++ paddle/ap/include/axpr/starred.h | 42 + paddle/ap/include/axpr/starred_method_class.h | 31 + paddle/ap/include/axpr/string.h | 31 + paddle/ap/include/axpr/string_method_class.h | 133 + paddle/ap/include/axpr/string_util.h | 112 + paddle/ap/include/axpr/to_string.h | 58 + paddle/ap/include/axpr/type.h | 75 + paddle/ap/include/axpr/type_method_class.h | 80 + paddle/ap/include/axpr/type_util.h | 75 + paddle/ap/include/axpr/unary_func.h | 53 + paddle/ap/include/axpr/value.h | 193 + paddle/ap/include/axpr/value_method_class.h | 44 + paddle/ap/include/code_gen/arg_source_ctx.h | 206 ++ .../ap/include/code_gen/arg_source_helper.h | 361 ++ paddle/ap/include/code_gen/arg_source_maker.h | 168 + 
.../ap/include/code_gen/builtin_frame_util.h | 54 + paddle/ap/include/code_gen/code_gen_ctx.h | 50 + .../code_gen/code_gen_ctx_method_class.h | 290 ++ paddle/ap/include/code_gen/code_gen_result.h | 40 + .../code_gen/code_gen_result_method_class.h | 25 + .../ap/include/code_gen/cuda_code_gen_util.h | 32 + .../include/code_gen/dim_expr_kernel_arg_id.h | 53 + .../dim_expr_kernel_arg_id_method_class.h | 88 + .../in_tensor_data_ptr_kernel_arg_id.h | 67 + ...nsor_data_ptr_kernel_arg_id_method_class.h | 86 + paddle/ap/include/code_gen/ir_op.h | 55 + paddle/ap/include/code_gen/kernel_arg.h | 36 + paddle/ap/include/code_gen/kernel_arg_id.h | 75 + .../include/code_gen/kernel_arg_id_helper.h | 87 + .../ap/include/code_gen/loop_anchor_flags.h | 25 + .../code_gen/matched_result_pattern_helper.h | 235 ++ paddle/ap/include/code_gen/op_code_gen_ctx.h | 44 + paddle/ap/include/code_gen/op_cuda_gen_impl.h | 34 + .../out_tensor_data_ptr_kernel_arg_id.h | 67 + ...nsor_data_ptr_kernel_arg_id_method_class.h | 86 + paddle/ap/include/code_gen/value.h | 37 + .../ap/include/code_gen/value_method_class.h | 27 + paddle/ap/include/code_module/adt.h | 32 + .../code_module/api_wrapper_project_maker.h | 254 ++ paddle/ap/include/code_module/arg_type.h | 82 + .../include/code_module/builtin_frame_util.h | 51 + paddle/ap/include/code_module/code_module.h | 37 + .../code_module/code_module_method_class.h | 28 + paddle/ap/include/code_module/data_type.h | 39 + paddle/ap/include/code_module/directory.h | 31 + .../code_module/directory_method_class.h | 27 + paddle/ap/include/code_module/file.h | 48 + paddle/ap/include/code_module/file_content.h | 30 + .../code_module/file_content_method_class.h | 26 + paddle/ap/include/code_module/func_declare.h | 38 + .../code_module/func_declare_method_class.h | 26 + .../code_module/module_compile_helper.h | 86 + .../code_module/module_to_axpr_helper.h | 154 + paddle/ap/include/code_module/package.h | 37 + .../code_module/package_method_class.h | 27 + 
paddle/ap/include/code_module/project.h | 37 + .../code_module/project_compile_helper.h | 126 + .../code_module/project_method_class.h | 31 + paddle/ap/include/code_module/rt_module.h | 27 + paddle/ap/include/code_module/soft_link.h | 30 + .../code_module/soft_link_method_class.h | 26 + paddle/ap/include/code_module/source_code.h | 42 + paddle/ap/include/code_module/value.h | 29 + .../include/code_module/value_method_class.h | 19 + paddle/ap/include/common/unique_id.h | 27 + paddle/ap/include/drr/builtin_frame_util.h | 42 + paddle/ap/include/drr/drr_ctx.h | 64 + paddle/ap/include/drr/drr_ctx_method_class.h | 34 + paddle/ap/include/drr/drr_graph_descriptor.h | 442 +++ paddle/ap/include/drr/drr_interpreter.h | 58 + paddle/ap/include/drr/drr_node_descriptor.h | 129 + paddle/ap/include/drr/drr_pass_type.h | 39 + paddle/ap/include/drr/drr_pass_type_helper.h | 41 + paddle/ap/include/drr/drr_value.h | 80 + paddle/ap/include/drr/drr_value_helper.h | 91 + paddle/ap/include/drr/ir_op.h | 63 + paddle/ap/include/drr/ir_value.h | 55 + paddle/ap/include/drr/native_ir_op.h | 58 + paddle/ap/include/drr/native_ir_op_declare.h | 71 + .../drr/native_ir_op_declare_method_class.h | 32 + .../include/drr/native_ir_op_method_class.h | 28 + paddle/ap/include/drr/native_ir_op_operand.h | 40 + paddle/ap/include/drr/native_ir_op_result.h | 40 + paddle/ap/include/drr/native_ir_value.h | 72 + .../drr/native_ir_value_method_class.h | 33 + paddle/ap/include/drr/node.h | 62 + paddle/ap/include/drr/op_pattern_ctx.h | 71 + .../drr/op_tensor_pattern_ctx_helper.h | 364 ++ paddle/ap/include/drr/opt_packed_ir_op.h | 58 + .../ap/include/drr/opt_packed_ir_op_declare.h | 78 + .../opt_packed_ir_op_declare_method_class.h | 29 + .../drr/opt_packed_ir_op_method_class.h | 28 + .../ap/include/drr/opt_packed_ir_op_operand.h | 40 + .../ap/include/drr/opt_packed_ir_op_result.h | 40 + paddle/ap/include/drr/packed_ir_op.h | 58 + paddle/ap/include/drr/packed_ir_op_declare.h | 93 + 
.../include/drr/packed_ir_op_declare_data.h | 31 + .../drr/packed_ir_op_declare_method_class.h | 32 + .../include/drr/packed_ir_op_method_class.h | 28 + paddle/ap/include/drr/packed_ir_op_operand.h | 40 + paddle/ap/include/drr/packed_ir_op_result.h | 40 + paddle/ap/include/drr/packed_ir_value.h | 97 + .../drr/packed_ir_value_method_class.h | 35 + .../drr/res_ptn_op_pattern_ctx_method_class.h | 40 + .../drr/res_ptn_packed_ir_op_declare_data.h | 35 + .../res_ptn_tensor_pattern_ctx_method_class.h | 31 + ...es_ptn_unbound_native_ir_op_method_class.h | 34 + ...es_ptn_unbound_packed_ir_op_method_class.h | 34 + .../include/drr/res_ptn_valid_out_ir_value.h | 58 + paddle/ap/include/drr/result_pattern_ctx.h | 55 + .../drr/result_pattern_ctx_method_class.h | 28 + paddle/ap/include/drr/result_pattern_helper.h | 134 + paddle/ap/include/drr/source_pattern_ctx.h | 51 + .../drr/source_pattern_ctx_method_class.h | 29 + .../drr/src_ptn_op_pattern_ctx_method_class.h | 38 + .../drr/src_ptn_packed_ir_op_declare_data.h | 34 + .../src_ptn_tensor_pattern_ctx_method_class.h | 32 + ...rc_ptn_unbound_native_ir_op_method_class.h | 34 + ...rc_ptn_unbound_packed_ir_op_method_class.h | 38 + .../include/drr/src_ptn_valid_in_ir_value.h | 43 + .../include/drr/src_ptn_valid_out_ir_value.h | 41 + paddle/ap/include/drr/tags.h | 40 + paddle/ap/include/drr/tensor_pattern_ctx.h | 70 + paddle/ap/include/drr/topo_kind.h | 34 + paddle/ap/include/drr/type.h | 22 + paddle/ap/include/drr/unbound_ir_value.h | 53 + .../drr/unbound_ir_value_method_class.h | 33 + paddle/ap/include/drr/unbound_native_ir_op.h | 68 + .../ap/include/drr/unbound_opt_packed_ir_op.h | 53 + .../unbound_opt_packed_ir_op_method_class.h | 36 + paddle/ap/include/drr/unbound_packed_ir_op.h | 68 + .../ap/include/drr/unbound_packed_ir_value.h | 52 + .../unbound_packed_ir_value_method_class.h | 29 + paddle/ap/include/drr/value.h | 27 + paddle/ap/include/drr/value_method_class.h | 38 + paddle/ap/include/env/ap_path.h | 42 + 
paddle/ap/include/fs/fs.h | 64 + paddle/ap/include/graph/adt.h | 20 + paddle/ap/include/graph/graph_descriptor.h | 49 + paddle/ap/include/graph/graph_helper.h | 178 + paddle/ap/include/graph/node.h | 86 + paddle/ap/include/graph/node_arena.h | 179 + paddle/ap/include/graph/node_descriptor.h | 36 + paddle/ap/include/graph/node_list.h | 113 + paddle/ap/include/graph/node_topo_cstr.h | 180 + paddle/ap/include/graph/tags.h | 28 + .../include/index_expr/builtin_frame_util.h | 44 + .../index_expr/dim_expr_cuda_code_generator.h | 163 + paddle/ap/include/index_expr/index_closure.h | 103 + paddle/ap/include/index_expr/index_expr.h | 171 + .../index_expr/index_expr_builtin_functions.h | 441 +++ .../index_expr/index_expr_interpreter.h | 47 + .../index_expr/index_expr_method_class.h | 46 + .../ap/include/index_expr/index_expr_util.h | 168 + .../ap/include/index_expr/index_tuple_expr.h | 196 + .../index_tuple_expr_cuda_code_generator.h | 97 + .../index_tuple_expr_method_class.h | 89 + .../op_index_tuple_expr_signature.h | 44 + ..._index_tuple_expr_signature_method_class.h | 91 + paddle/ap/include/index_expr/op_signature.h | 78 + paddle/ap/include/index_expr/slice.h | 47 + .../include/index_expr/slice_method_class.h | 45 + .../index_expr/valid_index_expr_builder.h | 248 ++ paddle/ap/include/index_expr/value.h | 32 + .../include/index_expr/value_method_class.h | 22 + paddle/ap/include/ir_match/graph_match_ctx.h | 280 ++ paddle/ap/include/ir_match/graph_matcher.h | 120 + paddle/ap/include/ir_match/ir_match_ctx.h | 50 + .../include/ir_match/native_or_ref_ir_value.h | 50 + paddle/ap/include/ir_match/op_match_ctx.h | 39 + .../ir_match/op_match_ctx_method_class.h | 120 + paddle/ap/include/ir_match/ref_match_ctx.h | 46 + paddle/ap/include/ir_match/ref_node_info.h | 47 + paddle/ap/include/ir_match/tags.h | 21 + paddle/ap/include/ir_match/tensor_match_ctx.h | 41 + .../ir_match/tensor_match_ctx_method_class.h | 131 + paddle/ap/include/ir_match/topo_match_ctx.h | 236 ++ 
paddle/ap/include/ir_match/topo_matcher.h | 398 ++ .../include/kernel_dispatch/ap_unary_kernel.h | 41 + paddle/ap/include/kernel_dispatch/arg_value.h | 36 + .../kernel_dispatch/builtin_frame_util.h | 41 + .../ap/include/kernel_dispatch/const_tensor.h | 73 + .../const_tensor_method_class.h | 110 + .../ap/include/kernel_dispatch/device_ctx.h | 36 + .../kernel_dispatch/device_ctx_method_class.h | 25 + .../ap/include/kernel_dispatch/dispatch_ctx.h | 46 + .../dispatch_ctx_method_class.h | 246 ++ .../kernel_dispatch/dispatch_raw_ctx.h | 73 + .../include/kernel_dispatch/mutable_tensor.h | 68 + .../mutable_tensor_method_class.h | 110 + .../ap/include/kernel_dispatch/typed_buffer.h | 40 + paddle/ap/include/kernel_dispatch/value.h | 32 + paddle/ap/include/memory/circlable_ref.h | 63 + paddle/ap/include/memory/circlable_ref_impl.h | 62 + .../include/memory/circlable_ref_impl_base.h | 67 + paddle/ap/include/memory/circlable_ref_list.h | 76 + .../include/memory/circlable_ref_list_base.h | 43 + paddle/ap/include/memory/guard.h | 35 + paddle/ap/include/paddle/builtin_frame_util.h | 33 + .../ap/include/paddle/const_meta_tensor_ptr.h | 36 + .../const_meta_tensor_ptr_method_class.h | 84 + ...r_const_meta_tensor_ptr_ptr_method_class.h | 79 + paddle/ap/include/paddle/ddim.h | 36 + paddle/ap/include/paddle/ddim_method_class.h | 79 + paddle/ap/include/paddle/indexed_ir_graph.h | 60 + .../ap/include/paddle/indexed_ir_graph_util.h | 224 ++ paddle/ap/include/paddle/indexed_ir_node.h | 105 + paddle/ap/include/paddle/meta_tensor_ptr.h | 36 + .../paddle/meta_tensor_ptr_method_class.h | 149 + .../ap/include/paddle/op_cuda_code_gen_impl.h | 1018 ++++++ paddle/ap/include/paddle/pass/ap_drr_helper.h | 56 + .../paddle/pass/ap_kernel_define_helper.h | 44 + .../paddle/pass/ap_lower_fusion_op_pass.h | 56 + .../include/paddle/pass/ap_registry_helper.h | 26 + paddle/ap/include/paddle/pass/ir_helper.h | 19 + .../paddle/pass/ir_helper_method_class.h | 42 + .../include/paddle/phi/ap_infer_meta_helper.h | 
34 + paddle/ap/include/paddle/phi/device_ctx.h | 41 + .../include/paddle/phi/kernel_define_helper.h | 33 + .../paddle/phi/kernel_dispatch_helper.h | 42 + paddle/ap/include/paddle/phi/place.h | 20 + .../include/paddle/phi/place_method_class.h | 103 + paddle/ap/include/paddle/phi/scalar_helper.h | 92 + .../ap/include/paddle/pir/attr_adt_type_id.h | 117 + paddle/ap/include/paddle/pir/attribute.h | 24 + .../paddle/pir/attribute_method_class.h | 264 ++ paddle/ap/include/paddle/pir/manual_op.h | 142 + paddle/ap/include/paddle/pir/op_dialect.h | 35 + ...packed_ir_op_inner_source_pattern_helper.h | 38 + paddle/ap/include/paddle/pir/pass.h | 30 + paddle/ap/include/paddle/pir/pass_manager.h | 30 + .../paddle/pir/pass_manager_method_class.h | 31 + .../ap/include/paddle/pir/pass_method_class.h | 27 + paddle/ap/include/paddle/pir/pir.h | 23 + .../ap/include/paddle/pir/pir_method_class.h | 32 + .../pir/pir_node_matched_src_ptn_ctx_helper.h | 62 + .../paddle/pir/pir_to_anf_expr_helper.h | 37 + paddle/ap/include/paddle/pir/program.h | 30 + .../include/paddle/pir/program_method_class.h | 27 + .../paddle/pir/shape_or_data_method_class.h | 33 + paddle/ap/include/paddle/pir/type.h | 19 + .../ap/include/paddle/pir/type_adt_type_id.h | 109 + .../ap/include/paddle/pir/type_method_class.h | 236 ++ .../ap/include/paddle/pir_graph_descriptor.h | 591 +++ paddle/ap/include/paddle/pir_node.h | 415 +++ .../ap/include/paddle/pir_node_descriptor.h | 213 ++ paddle/ap/include/paddle/pir_node_helper.h | 86 + .../ap/include/paddle/pir_node_method_class.h | 388 ++ paddle/ap/include/paddle/pir_util.h | 63 + ..._vector_meta_tensor_ptr_ptr_method_class.h | 79 + paddle/ap/include/preprocessor/preprocessor.h | 18 + .../abstract_drr_pass_registry_item.h | 32 + .../access_topo_drr_pass_registry_item.h | 33 + .../ap/include/registry/builtin_frame_util.h | 40 + .../registry/classic_drr_pass_registry_item.h | 32 + paddle/ap/include/registry/registry.h | 53 + paddle/ap/include/registry/registry_class.h | 143 + 
paddle/ap/include/registry/registry_mgr.h | 103 + .../ap/include/registry/registry_singleton.h | 80 + paddle/ap/include/registry/value.h | 30 + .../drr_node_attr_to_anf_expr_helper.h | 32 + .../reified_drr/matched_src_ptn_ctx_helper.h | 44 + .../reified_drr_pass_dump_helper.h | 41 + .../reified_drr/reified_res_ptn_axpr_maker.h | 45 + .../reified_drr/reified_src_ptn_axpr_maker.h | 45 + paddle/ap/include/rt_module/arg_value.h | 70 + paddle/ap/include/rt_module/dl_function.h | 53 + paddle/ap/include/rt_module/dl_handle.h | 34 + paddle/ap/include/rt_module/function.h | 35 + paddle/ap/include/rt_module/function_helper.h | 122 + .../include/rt_module/function_method_class.h | 46 + paddle/ap/include/rt_module/module.h | 32 + .../ap/include/rt_module/naive_dl_handler.h | 85 + paddle/ap/include/rt_module/naive_module.h | 63 + .../ap/include/rt_module/naive_module_maker.h | 151 + paddle/ap/src/axpr/anf_expr.cc | 133 + paddle/ap/src/axpr/builtin_functions.cc | 505 +++ paddle/ap/src/axpr/core_expr.cc | 81 + paddle/ap/src/axpr/exception_method_class.cc | 38 + paddle/ap/src/axpr/interpreter.cc | 39 + paddle/ap/src/axpr/s_expr.cc | 79 + .../code_gen/code_gen_result_method_class.cc | 89 + .../code_module/code_module_method_class.cc | 73 + .../src/code_module/directory_method_class.cc | 71 + .../code_module/file_content_method_class.cc | 46 + .../code_module/func_declare_method_class.cc | 82 + .../src/code_module/package_method_class.cc | 100 + .../src/code_module/project_method_class.cc | 97 + .../src/code_module/soft_link_method_class.cc | 47 + paddle/ap/src/drr/drr_ctx_method_class.cc | 233 ++ paddle/ap/src/drr/drr_interpreter.cc | 119 + .../drr/native_ir_op_declare_method_class.cc | 87 + .../ap/src/drr/native_ir_op_method_class.cc | 52 + .../src/drr/native_ir_value_method_class.cc | 150 + .../opt_packed_ir_op_declare_method_class.cc | 53 + .../src/drr/opt_packed_ir_op_method_class.cc | 53 + .../drr/packed_ir_op_declare_method_class.cc | 89 + 
.../ap/src/drr/packed_ir_op_method_class.cc | 53 + .../src/drr/packed_ir_value_method_class.cc | 168 + .../res_ptn_op_pattern_ctx_method_class.cc | 225 ++ ...res_ptn_tensor_pattern_ctx_method_class.cc | 126 + ...s_ptn_unbound_native_ir_op_method_class.cc | 176 + ...s_ptn_unbound_packed_ir_op_method_class.cc | 132 + .../drr/result_pattern_ctx_method_class.cc | 54 + .../drr/source_pattern_ctx_method_class.cc | 54 + .../src_ptn_op_pattern_ctx_method_class.cc | 203 ++ ...src_ptn_tensor_pattern_ctx_method_class.cc | 87 + ...c_ptn_unbound_native_ir_op_method_class.cc | 210 ++ ...c_ptn_unbound_packed_ir_op_method_class.cc | 307 ++ .../src/drr/unbound_ir_value_method_class.cc | 83 + .../unbound_opt_packed_ir_op_method_class.cc | 270 ++ .../unbound_packed_ir_value_method_class.cc | 54 + paddle/ap/src/index_expr/index_closure.cc | 100 + .../index_expr_builtin_functions.cc | 22 + paddle/ap/src/index_expr/index_expr_util.cc | 18 + .../index_expr/valid_index_expr_builder.cc | 22 + .../device_ctx_method_class.cc | 57 + paddle/ap/src/paddle/pass/ap_drr_helper.cc | 64 + .../paddle/pass/ap_kernel_define_helper.cc | 65 + .../paddle/pass/ap_lower_fusion_op_pass.cc | 3212 +++++++++++++++++ .../ap/src/paddle/pass/ap_registry_helper.cc | 36 + .../src/paddle/pass/ir_helper_method_class.cc | 390 ++ paddle/ap/src/paddle/pass/op_factory.cc | 216 ++ paddle/ap/src/paddle/pass/op_factory.h | 39 + .../ap/src/paddle/phi/ap_infer_meta_helper.cc | 95 + paddle/ap/src/paddle/phi/ap_unary_kernel.cc | 304 ++ .../ap/src/paddle/phi/kernel_define_helper.cc | 47 + .../src/paddle/phi/kernel_dispatch_helper.cc | 55 + .../src/paddle/pir/attribute_method_class.cc | 584 +++ paddle/ap/src/paddle/pir/manual_op.cc | 153 + paddle/ap/src/paddle/pir/op_dialect.cc | 39 + ...acked_ir_op_inner_source_pattern_helper.cc | 135 + .../paddle/pir/pass_manager_method_class.cc | 79 + paddle/ap/src/paddle/pir/pass_method_class.cc | 43 + paddle/ap/src/paddle/pir/pir_method_class.cc | 70 + 
.../pir_node_matched_src_ptn_ctx_helper.cc | 430 +++ .../src/paddle/pir/pir_to_anf_expr_helper.cc | 707 ++++ .../ap/src/paddle/pir/program_method_class.cc | 210 ++ .../paddle/pir/shape_or_data_method_class.cc | 41 + paddle/ap/src/paddle/pir/type_method_class.cc | 506 +++ .../reified_drr_pass_dump_helper.cc | 287 ++ .../reified_drr/reified_res_ptn_axpr_maker.cc | 275 ++ .../reified_drr/reified_src_ptn_axpr_maker.cc | 332 ++ .../operator/transforms/CMakeLists.txt | 3 +- .../operator/transforms/add_cinn_pass.cc | 54 +- paddle/common/adt_type_id.h | 4 +- paddle/common/flags.cc | 2 + paddle/phi/CMakeLists.txt | 3 +- paddle/phi/infermeta/multiary.cc | 19 + paddle/phi/infermeta/multiary.h | 9 + paddle/phi/kernels/gpu/ap_unary.cu | 96 + paddle/phi/ops/yaml/ops.yaml | 8 + paddle/pir/include/core/op_operand.h | 13 + paddle/pir/include/core/op_result.h | 4 +- python/setup.py.in | 3 +- 463 files changed, 46703 insertions(+), 18 deletions(-) create mode 100644 paddle/ap/CMakeLists.txt create mode 100644 paddle/ap/include/adt/adt.h create mode 100644 paddle/ap/include/adt/bfs_walker.h create mode 100644 paddle/ap/include/adt/topo_walker.h create mode 100644 paddle/ap/include/axpr/abstract_list.h create mode 100644 paddle/ap/include/axpr/adt.h create mode 100644 paddle/ap/include/axpr/anf_expr.h create mode 100644 paddle/ap/include/axpr/anf_expr_builder.h create mode 100644 paddle/ap/include/axpr/anf_expr_helper.h create mode 100644 paddle/ap/include/axpr/anf_expr_util.h create mode 100644 paddle/ap/include/axpr/atomic.h create mode 100644 paddle/ap/include/axpr/atomic_builder.h create mode 100644 paddle/ap/include/axpr/attr_map.h create mode 100644 paddle/ap/include/axpr/attr_map_method_class.h create mode 100644 paddle/ap/include/axpr/binary_func.h create mode 100644 paddle/ap/include/axpr/bool.h create mode 100644 paddle/ap/include/axpr/bool_helper.h create mode 100644 paddle/ap/include/axpr/bool_int_double.h create mode 100644 
paddle/ap/include/axpr/bool_int_double_arithmetic_util.h create mode 100644 paddle/ap/include/axpr/bool_int_double_helper.h create mode 100644 paddle/ap/include/axpr/bool_method_class.h create mode 100644 paddle/ap/include/axpr/builtin_class_instance.h create mode 100644 paddle/ap/include/axpr/builtin_class_instance_method_class.h create mode 100644 paddle/ap/include/axpr/builtin_classes.h create mode 100644 paddle/ap/include/axpr/builtin_environment.h create mode 100644 paddle/ap/include/axpr/builtin_frame_util.h create mode 100644 paddle/ap/include/axpr/builtin_func_name_mgr.h create mode 100644 paddle/ap/include/axpr/builtin_func_type.h create mode 100644 paddle/ap/include/axpr/builtin_func_type_method_class.h create mode 100644 paddle/ap/include/axpr/builtin_functions.h create mode 100644 paddle/ap/include/axpr/builtin_high_order_func_type.h create mode 100644 paddle/ap/include/axpr/builtin_high_order_func_type_method_class.h create mode 100644 paddle/ap/include/axpr/builtin_serializable_attr_map.h create mode 100644 paddle/ap/include/axpr/builtin_serializable_attr_map_method_class.h create mode 100644 paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h create mode 100644 paddle/ap/include/axpr/builtin_symbol.h create mode 100644 paddle/ap/include/axpr/builtin_symbol_method_class.h create mode 100644 paddle/ap/include/axpr/call_environment.h create mode 100644 paddle/ap/include/axpr/callable_helper.h create mode 100644 paddle/ap/include/axpr/class_attrs.h create mode 100644 paddle/ap/include/axpr/class_attrs_helper.h create mode 100644 paddle/ap/include/axpr/class_instance.h create mode 100644 paddle/ap/include/axpr/class_instance_method_class.h create mode 100644 paddle/ap/include/axpr/class_ops.h create mode 100644 paddle/ap/include/axpr/closure.h create mode 100644 paddle/ap/include/axpr/closure_method_class.h create mode 100644 paddle/ap/include/axpr/const_global_environment.h create mode 100644 paddle/ap/include/axpr/constants.h create 
mode 100644 paddle/ap/include/axpr/continuation.h create mode 100644 paddle/ap/include/axpr/continuation_method_class.h create mode 100644 paddle/ap/include/axpr/core_expr.h create mode 100644 paddle/ap/include/axpr/core_expr_builder.h create mode 100644 paddle/ap/include/axpr/cps_interpreter.h create mode 100644 paddle/ap/include/axpr/data_type.h create mode 100644 paddle/ap/include/axpr/data_type_method_class.h create mode 100644 paddle/ap/include/axpr/data_type_util.h create mode 100644 paddle/ap/include/axpr/data_value.h create mode 100644 paddle/ap/include/axpr/data_value_method_class.h create mode 100644 paddle/ap/include/axpr/data_value_util.h create mode 100644 paddle/ap/include/axpr/dim_expr.h create mode 100644 paddle/ap/include/axpr/dim_expr_method_class.h create mode 100644 paddle/ap/include/axpr/environment.h create mode 100644 paddle/ap/include/axpr/error.h create mode 100644 paddle/ap/include/axpr/exception_method_class.h create mode 100644 paddle/ap/include/axpr/float.h create mode 100644 paddle/ap/include/axpr/float_method_class.h create mode 100644 paddle/ap/include/axpr/frame.h create mode 100644 paddle/ap/include/axpr/function.h create mode 100644 paddle/ap/include/axpr/function_method_class.h create mode 100644 paddle/ap/include/axpr/global_environment.h create mode 100644 paddle/ap/include/axpr/hash.h create mode 100644 paddle/ap/include/axpr/instance_attrs.h create mode 100644 paddle/ap/include/axpr/int.h create mode 100644 paddle/ap/include/axpr/int_data_type.h create mode 100644 paddle/ap/include/axpr/int_method_class.h create mode 100644 paddle/ap/include/axpr/interpreter.h create mode 100644 paddle/ap/include/axpr/interpreter_base.h create mode 100644 paddle/ap/include/axpr/lambda_expr_builder.h create mode 100644 paddle/ap/include/axpr/list.h create mode 100644 paddle/ap/include/axpr/list_method_class.h create mode 100644 paddle/ap/include/axpr/method.h create mode 100644 paddle/ap/include/axpr/method_class.h create mode 100644 
paddle/ap/include/axpr/method_method_class.h create mode 100644 paddle/ap/include/axpr/module_mgr.h create mode 100644 paddle/ap/include/axpr/module_mgr_helper.h create mode 100644 paddle/ap/include/axpr/mutable_global_environment.h create mode 100644 paddle/ap/include/axpr/mutable_list.h create mode 100644 paddle/ap/include/axpr/mutable_list_method_class.h create mode 100644 paddle/ap/include/axpr/mutable_ordered_dict.h create mode 100644 paddle/ap/include/axpr/mutable_ordered_dict_method_class.h create mode 100644 paddle/ap/include/axpr/naive_class_ops.h create mode 100644 paddle/ap/include/axpr/nothing.h create mode 100644 paddle/ap/include/axpr/nothing_method_class.h create mode 100644 paddle/ap/include/axpr/ordered_dict.h create mode 100644 paddle/ap/include/axpr/ordered_dict_method_class.h create mode 100644 paddle/ap/include/axpr/packed_args.h create mode 100644 paddle/ap/include/axpr/packed_args_method_class.h create mode 100644 paddle/ap/include/axpr/pointer_type.h create mode 100644 paddle/ap/include/axpr/pointer_type_method_class.h create mode 100644 paddle/ap/include/axpr/pointer_type_util.h create mode 100644 paddle/ap/include/axpr/pointer_value.h create mode 100644 paddle/ap/include/axpr/pointer_value_method_class.h create mode 100644 paddle/ap/include/axpr/s_expr.h create mode 100644 paddle/ap/include/axpr/serializable_list.h create mode 100644 paddle/ap/include/axpr/serializable_list_method_class.h create mode 100644 paddle/ap/include/axpr/serializable_value.h create mode 100644 paddle/ap/include/axpr/serializable_value_helper.h create mode 100644 paddle/ap/include/axpr/starred.h create mode 100644 paddle/ap/include/axpr/starred_method_class.h create mode 100644 paddle/ap/include/axpr/string.h create mode 100644 paddle/ap/include/axpr/string_method_class.h create mode 100644 paddle/ap/include/axpr/string_util.h create mode 100644 paddle/ap/include/axpr/to_string.h create mode 100644 paddle/ap/include/axpr/type.h create mode 100644 
paddle/ap/include/axpr/type_method_class.h create mode 100644 paddle/ap/include/axpr/type_util.h create mode 100644 paddle/ap/include/axpr/unary_func.h create mode 100644 paddle/ap/include/axpr/value.h create mode 100644 paddle/ap/include/axpr/value_method_class.h create mode 100644 paddle/ap/include/code_gen/arg_source_ctx.h create mode 100644 paddle/ap/include/code_gen/arg_source_helper.h create mode 100644 paddle/ap/include/code_gen/arg_source_maker.h create mode 100644 paddle/ap/include/code_gen/builtin_frame_util.h create mode 100644 paddle/ap/include/code_gen/code_gen_ctx.h create mode 100644 paddle/ap/include/code_gen/code_gen_ctx_method_class.h create mode 100644 paddle/ap/include/code_gen/code_gen_result.h create mode 100644 paddle/ap/include/code_gen/code_gen_result_method_class.h create mode 100644 paddle/ap/include/code_gen/cuda_code_gen_util.h create mode 100644 paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h create mode 100644 paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h create mode 100644 paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h create mode 100644 paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h create mode 100644 paddle/ap/include/code_gen/ir_op.h create mode 100644 paddle/ap/include/code_gen/kernel_arg.h create mode 100644 paddle/ap/include/code_gen/kernel_arg_id.h create mode 100644 paddle/ap/include/code_gen/kernel_arg_id_helper.h create mode 100644 paddle/ap/include/code_gen/loop_anchor_flags.h create mode 100644 paddle/ap/include/code_gen/matched_result_pattern_helper.h create mode 100644 paddle/ap/include/code_gen/op_code_gen_ctx.h create mode 100644 paddle/ap/include/code_gen/op_cuda_gen_impl.h create mode 100644 paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h create mode 100644 paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h create mode 100644 paddle/ap/include/code_gen/value.h create mode 100644 
paddle/ap/include/code_gen/value_method_class.h create mode 100644 paddle/ap/include/code_module/adt.h create mode 100644 paddle/ap/include/code_module/api_wrapper_project_maker.h create mode 100644 paddle/ap/include/code_module/arg_type.h create mode 100644 paddle/ap/include/code_module/builtin_frame_util.h create mode 100644 paddle/ap/include/code_module/code_module.h create mode 100644 paddle/ap/include/code_module/code_module_method_class.h create mode 100644 paddle/ap/include/code_module/data_type.h create mode 100644 paddle/ap/include/code_module/directory.h create mode 100644 paddle/ap/include/code_module/directory_method_class.h create mode 100644 paddle/ap/include/code_module/file.h create mode 100644 paddle/ap/include/code_module/file_content.h create mode 100644 paddle/ap/include/code_module/file_content_method_class.h create mode 100644 paddle/ap/include/code_module/func_declare.h create mode 100644 paddle/ap/include/code_module/func_declare_method_class.h create mode 100644 paddle/ap/include/code_module/module_compile_helper.h create mode 100644 paddle/ap/include/code_module/module_to_axpr_helper.h create mode 100644 paddle/ap/include/code_module/package.h create mode 100644 paddle/ap/include/code_module/package_method_class.h create mode 100644 paddle/ap/include/code_module/project.h create mode 100644 paddle/ap/include/code_module/project_compile_helper.h create mode 100644 paddle/ap/include/code_module/project_method_class.h create mode 100644 paddle/ap/include/code_module/rt_module.h create mode 100644 paddle/ap/include/code_module/soft_link.h create mode 100644 paddle/ap/include/code_module/soft_link_method_class.h create mode 100644 paddle/ap/include/code_module/source_code.h create mode 100644 paddle/ap/include/code_module/value.h create mode 100644 paddle/ap/include/code_module/value_method_class.h create mode 100644 paddle/ap/include/common/unique_id.h create mode 100644 paddle/ap/include/drr/builtin_frame_util.h create mode 100644 
paddle/ap/include/drr/drr_ctx.h create mode 100644 paddle/ap/include/drr/drr_ctx_method_class.h create mode 100644 paddle/ap/include/drr/drr_graph_descriptor.h create mode 100644 paddle/ap/include/drr/drr_interpreter.h create mode 100644 paddle/ap/include/drr/drr_node_descriptor.h create mode 100644 paddle/ap/include/drr/drr_pass_type.h create mode 100644 paddle/ap/include/drr/drr_pass_type_helper.h create mode 100644 paddle/ap/include/drr/drr_value.h create mode 100644 paddle/ap/include/drr/drr_value_helper.h create mode 100644 paddle/ap/include/drr/ir_op.h create mode 100644 paddle/ap/include/drr/ir_value.h create mode 100644 paddle/ap/include/drr/native_ir_op.h create mode 100644 paddle/ap/include/drr/native_ir_op_declare.h create mode 100644 paddle/ap/include/drr/native_ir_op_declare_method_class.h create mode 100644 paddle/ap/include/drr/native_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/native_ir_op_operand.h create mode 100644 paddle/ap/include/drr/native_ir_op_result.h create mode 100644 paddle/ap/include/drr/native_ir_value.h create mode 100644 paddle/ap/include/drr/native_ir_value_method_class.h create mode 100644 paddle/ap/include/drr/node.h create mode 100644 paddle/ap/include/drr/op_pattern_ctx.h create mode 100644 paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h create mode 100644 paddle/ap/include/drr/opt_packed_ir_op.h create mode 100644 paddle/ap/include/drr/opt_packed_ir_op_declare.h create mode 100644 paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h create mode 100644 paddle/ap/include/drr/opt_packed_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/opt_packed_ir_op_operand.h create mode 100644 paddle/ap/include/drr/opt_packed_ir_op_result.h create mode 100644 paddle/ap/include/drr/packed_ir_op.h create mode 100644 paddle/ap/include/drr/packed_ir_op_declare.h create mode 100644 paddle/ap/include/drr/packed_ir_op_declare_data.h create mode 100644 
paddle/ap/include/drr/packed_ir_op_declare_method_class.h create mode 100644 paddle/ap/include/drr/packed_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/packed_ir_op_operand.h create mode 100644 paddle/ap/include/drr/packed_ir_op_result.h create mode 100644 paddle/ap/include/drr/packed_ir_value.h create mode 100644 paddle/ap/include/drr/packed_ir_value_method_class.h create mode 100644 paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h create mode 100644 paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h create mode 100644 paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h create mode 100644 paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/res_ptn_valid_out_ir_value.h create mode 100644 paddle/ap/include/drr/result_pattern_ctx.h create mode 100644 paddle/ap/include/drr/result_pattern_ctx_method_class.h create mode 100644 paddle/ap/include/drr/result_pattern_helper.h create mode 100644 paddle/ap/include/drr/source_pattern_ctx.h create mode 100644 paddle/ap/include/drr/source_pattern_ctx_method_class.h create mode 100644 paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h create mode 100644 paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h create mode 100644 paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h create mode 100644 paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/src_ptn_valid_in_ir_value.h create mode 100644 paddle/ap/include/drr/src_ptn_valid_out_ir_value.h create mode 100644 paddle/ap/include/drr/tags.h create mode 100644 paddle/ap/include/drr/tensor_pattern_ctx.h create mode 100644 paddle/ap/include/drr/topo_kind.h create mode 100644 paddle/ap/include/drr/type.h create mode 100644 
paddle/ap/include/drr/unbound_ir_value.h create mode 100644 paddle/ap/include/drr/unbound_ir_value_method_class.h create mode 100644 paddle/ap/include/drr/unbound_native_ir_op.h create mode 100644 paddle/ap/include/drr/unbound_opt_packed_ir_op.h create mode 100644 paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h create mode 100644 paddle/ap/include/drr/unbound_packed_ir_op.h create mode 100644 paddle/ap/include/drr/unbound_packed_ir_value.h create mode 100644 paddle/ap/include/drr/unbound_packed_ir_value_method_class.h create mode 100644 paddle/ap/include/drr/value.h create mode 100644 paddle/ap/include/drr/value_method_class.h create mode 100644 paddle/ap/include/env/ap_path.h create mode 100644 paddle/ap/include/fs/fs.h create mode 100644 paddle/ap/include/graph/adt.h create mode 100644 paddle/ap/include/graph/graph_descriptor.h create mode 100644 paddle/ap/include/graph/graph_helper.h create mode 100644 paddle/ap/include/graph/node.h create mode 100644 paddle/ap/include/graph/node_arena.h create mode 100644 paddle/ap/include/graph/node_descriptor.h create mode 100644 paddle/ap/include/graph/node_list.h create mode 100644 paddle/ap/include/graph/node_topo_cstr.h create mode 100644 paddle/ap/include/graph/tags.h create mode 100644 paddle/ap/include/index_expr/builtin_frame_util.h create mode 100644 paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h create mode 100644 paddle/ap/include/index_expr/index_closure.h create mode 100644 paddle/ap/include/index_expr/index_expr.h create mode 100644 paddle/ap/include/index_expr/index_expr_builtin_functions.h create mode 100644 paddle/ap/include/index_expr/index_expr_interpreter.h create mode 100644 paddle/ap/include/index_expr/index_expr_method_class.h create mode 100644 paddle/ap/include/index_expr/index_expr_util.h create mode 100644 paddle/ap/include/index_expr/index_tuple_expr.h create mode 100644 paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h create mode 100644 
paddle/ap/include/index_expr/index_tuple_expr_method_class.h create mode 100644 paddle/ap/include/index_expr/op_index_tuple_expr_signature.h create mode 100644 paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h create mode 100644 paddle/ap/include/index_expr/op_signature.h create mode 100644 paddle/ap/include/index_expr/slice.h create mode 100644 paddle/ap/include/index_expr/slice_method_class.h create mode 100644 paddle/ap/include/index_expr/valid_index_expr_builder.h create mode 100644 paddle/ap/include/index_expr/value.h create mode 100644 paddle/ap/include/index_expr/value_method_class.h create mode 100644 paddle/ap/include/ir_match/graph_match_ctx.h create mode 100644 paddle/ap/include/ir_match/graph_matcher.h create mode 100644 paddle/ap/include/ir_match/ir_match_ctx.h create mode 100644 paddle/ap/include/ir_match/native_or_ref_ir_value.h create mode 100644 paddle/ap/include/ir_match/op_match_ctx.h create mode 100644 paddle/ap/include/ir_match/op_match_ctx_method_class.h create mode 100644 paddle/ap/include/ir_match/ref_match_ctx.h create mode 100644 paddle/ap/include/ir_match/ref_node_info.h create mode 100644 paddle/ap/include/ir_match/tags.h create mode 100644 paddle/ap/include/ir_match/tensor_match_ctx.h create mode 100644 paddle/ap/include/ir_match/tensor_match_ctx_method_class.h create mode 100644 paddle/ap/include/ir_match/topo_match_ctx.h create mode 100644 paddle/ap/include/ir_match/topo_matcher.h create mode 100644 paddle/ap/include/kernel_dispatch/ap_unary_kernel.h create mode 100644 paddle/ap/include/kernel_dispatch/arg_value.h create mode 100644 paddle/ap/include/kernel_dispatch/builtin_frame_util.h create mode 100644 paddle/ap/include/kernel_dispatch/const_tensor.h create mode 100644 paddle/ap/include/kernel_dispatch/const_tensor_method_class.h create mode 100644 paddle/ap/include/kernel_dispatch/device_ctx.h create mode 100644 paddle/ap/include/kernel_dispatch/device_ctx_method_class.h create mode 100644 
paddle/ap/include/kernel_dispatch/dispatch_ctx.h create mode 100644 paddle/ap/include/kernel_dispatch/dispatch_ctx_method_class.h create mode 100644 paddle/ap/include/kernel_dispatch/dispatch_raw_ctx.h create mode 100644 paddle/ap/include/kernel_dispatch/mutable_tensor.h create mode 100644 paddle/ap/include/kernel_dispatch/mutable_tensor_method_class.h create mode 100644 paddle/ap/include/kernel_dispatch/typed_buffer.h create mode 100644 paddle/ap/include/kernel_dispatch/value.h create mode 100644 paddle/ap/include/memory/circlable_ref.h create mode 100644 paddle/ap/include/memory/circlable_ref_impl.h create mode 100644 paddle/ap/include/memory/circlable_ref_impl_base.h create mode 100644 paddle/ap/include/memory/circlable_ref_list.h create mode 100644 paddle/ap/include/memory/circlable_ref_list_base.h create mode 100644 paddle/ap/include/memory/guard.h create mode 100644 paddle/ap/include/paddle/builtin_frame_util.h create mode 100644 paddle/ap/include/paddle/const_meta_tensor_ptr.h create mode 100644 paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h create mode 100644 paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h create mode 100644 paddle/ap/include/paddle/ddim.h create mode 100644 paddle/ap/include/paddle/ddim_method_class.h create mode 100644 paddle/ap/include/paddle/indexed_ir_graph.h create mode 100644 paddle/ap/include/paddle/indexed_ir_graph_util.h create mode 100644 paddle/ap/include/paddle/indexed_ir_node.h create mode 100644 paddle/ap/include/paddle/meta_tensor_ptr.h create mode 100644 paddle/ap/include/paddle/meta_tensor_ptr_method_class.h create mode 100644 paddle/ap/include/paddle/op_cuda_code_gen_impl.h create mode 100644 paddle/ap/include/paddle/pass/ap_drr_helper.h create mode 100644 paddle/ap/include/paddle/pass/ap_kernel_define_helper.h create mode 100644 paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h create mode 100644 paddle/ap/include/paddle/pass/ap_registry_helper.h create mode 100644 
paddle/ap/include/paddle/pass/ir_helper.h create mode 100644 paddle/ap/include/paddle/pass/ir_helper_method_class.h create mode 100644 paddle/ap/include/paddle/phi/ap_infer_meta_helper.h create mode 100644 paddle/ap/include/paddle/phi/device_ctx.h create mode 100644 paddle/ap/include/paddle/phi/kernel_define_helper.h create mode 100644 paddle/ap/include/paddle/phi/kernel_dispatch_helper.h create mode 100644 paddle/ap/include/paddle/phi/place.h create mode 100644 paddle/ap/include/paddle/phi/place_method_class.h create mode 100644 paddle/ap/include/paddle/phi/scalar_helper.h create mode 100644 paddle/ap/include/paddle/pir/attr_adt_type_id.h create mode 100644 paddle/ap/include/paddle/pir/attribute.h create mode 100644 paddle/ap/include/paddle/pir/attribute_method_class.h create mode 100644 paddle/ap/include/paddle/pir/manual_op.h create mode 100644 paddle/ap/include/paddle/pir/op_dialect.h create mode 100644 paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h create mode 100644 paddle/ap/include/paddle/pir/pass.h create mode 100644 paddle/ap/include/paddle/pir/pass_manager.h create mode 100644 paddle/ap/include/paddle/pir/pass_manager_method_class.h create mode 100644 paddle/ap/include/paddle/pir/pass_method_class.h create mode 100644 paddle/ap/include/paddle/pir/pir.h create mode 100644 paddle/ap/include/paddle/pir/pir_method_class.h create mode 100644 paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h create mode 100644 paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h create mode 100644 paddle/ap/include/paddle/pir/program.h create mode 100644 paddle/ap/include/paddle/pir/program_method_class.h create mode 100644 paddle/ap/include/paddle/pir/shape_or_data_method_class.h create mode 100644 paddle/ap/include/paddle/pir/type.h create mode 100644 paddle/ap/include/paddle/pir/type_adt_type_id.h create mode 100644 paddle/ap/include/paddle/pir/type_method_class.h create mode 100644 paddle/ap/include/paddle/pir_graph_descriptor.h 
create mode 100644 paddle/ap/include/paddle/pir_node.h create mode 100644 paddle/ap/include/paddle/pir_node_descriptor.h create mode 100644 paddle/ap/include/paddle/pir_node_helper.h create mode 100644 paddle/ap/include/paddle/pir_node_method_class.h create mode 100644 paddle/ap/include/paddle/pir_util.h create mode 100644 paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h create mode 100644 paddle/ap/include/preprocessor/preprocessor.h create mode 100644 paddle/ap/include/registry/abstract_drr_pass_registry_item.h create mode 100644 paddle/ap/include/registry/access_topo_drr_pass_registry_item.h create mode 100644 paddle/ap/include/registry/builtin_frame_util.h create mode 100644 paddle/ap/include/registry/classic_drr_pass_registry_item.h create mode 100644 paddle/ap/include/registry/registry.h create mode 100644 paddle/ap/include/registry/registry_class.h create mode 100644 paddle/ap/include/registry/registry_mgr.h create mode 100644 paddle/ap/include/registry/registry_singleton.h create mode 100644 paddle/ap/include/registry/value.h create mode 100644 paddle/ap/include/reified_drr/drr_node_attr_to_anf_expr_helper.h create mode 100644 paddle/ap/include/reified_drr/matched_src_ptn_ctx_helper.h create mode 100644 paddle/ap/include/reified_drr/reified_drr_pass_dump_helper.h create mode 100644 paddle/ap/include/reified_drr/reified_res_ptn_axpr_maker.h create mode 100644 paddle/ap/include/reified_drr/reified_src_ptn_axpr_maker.h create mode 100644 paddle/ap/include/rt_module/arg_value.h create mode 100644 paddle/ap/include/rt_module/dl_function.h create mode 100644 paddle/ap/include/rt_module/dl_handle.h create mode 100644 paddle/ap/include/rt_module/function.h create mode 100644 paddle/ap/include/rt_module/function_helper.h create mode 100644 paddle/ap/include/rt_module/function_method_class.h create mode 100644 paddle/ap/include/rt_module/module.h create mode 100644 paddle/ap/include/rt_module/naive_dl_handler.h create mode 100644 
paddle/ap/include/rt_module/naive_module.h create mode 100644 paddle/ap/include/rt_module/naive_module_maker.h create mode 100644 paddle/ap/src/axpr/anf_expr.cc create mode 100644 paddle/ap/src/axpr/builtin_functions.cc create mode 100644 paddle/ap/src/axpr/core_expr.cc create mode 100644 paddle/ap/src/axpr/exception_method_class.cc create mode 100644 paddle/ap/src/axpr/interpreter.cc create mode 100644 paddle/ap/src/axpr/s_expr.cc create mode 100644 paddle/ap/src/code_gen/code_gen_result_method_class.cc create mode 100644 paddle/ap/src/code_module/code_module_method_class.cc create mode 100644 paddle/ap/src/code_module/directory_method_class.cc create mode 100644 paddle/ap/src/code_module/file_content_method_class.cc create mode 100644 paddle/ap/src/code_module/func_declare_method_class.cc create mode 100644 paddle/ap/src/code_module/package_method_class.cc create mode 100644 paddle/ap/src/code_module/project_method_class.cc create mode 100644 paddle/ap/src/code_module/soft_link_method_class.cc create mode 100644 paddle/ap/src/drr/drr_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/drr_interpreter.cc create mode 100644 paddle/ap/src/drr/native_ir_op_declare_method_class.cc create mode 100644 paddle/ap/src/drr/native_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/native_ir_value_method_class.cc create mode 100644 paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc create mode 100644 paddle/ap/src/drr/opt_packed_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/packed_ir_op_declare_method_class.cc create mode 100644 paddle/ap/src/drr/packed_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/packed_ir_value_method_class.cc create mode 100644 paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc create mode 100644 
paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/result_pattern_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/source_pattern_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc create mode 100644 paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/unbound_ir_value_method_class.cc create mode 100644 paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc create mode 100644 paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc create mode 100644 paddle/ap/src/index_expr/index_closure.cc create mode 100644 paddle/ap/src/index_expr/index_expr_builtin_functions.cc create mode 100644 paddle/ap/src/index_expr/index_expr_util.cc create mode 100644 paddle/ap/src/index_expr/valid_index_expr_builder.cc create mode 100644 paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc create mode 100644 paddle/ap/src/paddle/pass/ap_drr_helper.cc create mode 100644 paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc create mode 100644 paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc create mode 100644 paddle/ap/src/paddle/pass/ap_registry_helper.cc create mode 100644 paddle/ap/src/paddle/pass/ir_helper_method_class.cc create mode 100644 paddle/ap/src/paddle/pass/op_factory.cc create mode 100644 paddle/ap/src/paddle/pass/op_factory.h create mode 100644 paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc create mode 100644 paddle/ap/src/paddle/phi/ap_unary_kernel.cc create mode 100644 paddle/ap/src/paddle/phi/kernel_define_helper.cc create mode 100644 paddle/ap/src/paddle/phi/kernel_dispatch_helper.cc create mode 100644 paddle/ap/src/paddle/pir/attribute_method_class.cc create mode 100644 paddle/ap/src/paddle/pir/manual_op.cc create mode 100644 
paddle/ap/src/paddle/pir/op_dialect.cc create mode 100644 paddle/ap/src/paddle/pir/packed_ir_op_inner_source_pattern_helper.cc create mode 100644 paddle/ap/src/paddle/pir/pass_manager_method_class.cc create mode 100644 paddle/ap/src/paddle/pir/pass_method_class.cc create mode 100644 paddle/ap/src/paddle/pir/pir_method_class.cc create mode 100644 paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc create mode 100644 paddle/ap/src/paddle/pir/pir_to_anf_expr_helper.cc create mode 100644 paddle/ap/src/paddle/pir/program_method_class.cc create mode 100644 paddle/ap/src/paddle/pir/shape_or_data_method_class.cc create mode 100644 paddle/ap/src/paddle/pir/type_method_class.cc create mode 100644 paddle/ap/src/reified_drr/reified_drr_pass_dump_helper.cc create mode 100644 paddle/ap/src/reified_drr/reified_res_ptn_axpr_maker.cc create mode 100644 paddle/ap/src/reified_drr/reified_src_ptn_axpr_maker.cc create mode 100644 paddle/phi/kernels/gpu/ap_unary.cu diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 5ba7f9ce3d3794..c41a49d3dd9682 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -5,6 +5,7 @@ set(PYTHON_TESTS_DIR add_subdirectory(utils) add_subdirectory(common) add_subdirectory(pir) +add_subdirectory(ap) add_subdirectory(scripts) add_subdirectory(testing) add_subdirectory(phi) diff --git a/paddle/ap/CMakeLists.txt b/paddle/ap/CMakeLists.txt new file mode 100644 index 00000000000000..828ad2ee0434b3 --- /dev/null +++ b/paddle/ap/CMakeLists.txt @@ -0,0 +1,76 @@ +file(GLOB_RECURSE axpr_srcs "src/axpr/*.cc") +set(axpr_deps common) +cc_library( + axpr + SRCS ${axpr_srcs} + DEPS ${axpr_deps}) + +file(GLOB_RECURSE index_expr_srcs "src/index_expr/*.cc") +set(index_expr_deps axpr) +cc_library( + index_expr + SRCS ${index_expr_srcs} + DEPS ${index_expr_deps}) + +file(GLOB_RECURSE ap_drr_srcs "src/drr/*.cc") +set(ap_drr_deps axpr) +cc_library( + ap_drr + SRCS ${ap_drr_srcs} + DEPS ${ap_drr_deps}) + +file(GLOB_RECURSE ap_code_module_srcs 
"src/code_module/*.cc") +set(ap_code_module_deps axpr) +cc_library( + ap_code_module + SRCS ${ap_code_module_srcs} + DEPS ${ap_code_module_deps}) + +file(GLOB_RECURSE ap_code_gen_srcs "src/code_gen/*.cc") +set(ap_code_gen_deps axpr ap_code_module) +cc_library( + ap_code_gen + SRCS ${ap_code_gen_srcs} + DEPS ${ap_code_gen_deps}) + +file(GLOB_RECURSE ap_kernel_dispatch_srcs "src/kernel_dispatch/*.cc") +set(ap_kernel_dispatch_deps axpr ap_code_module ap_code_gen) +cc_library( + ap_kernel_dispatch + SRCS ${ap_kernel_dispatch_srcs} + DEPS ${ap_kernel_dispatch_deps}) + +file(GLOB_RECURSE ap_phi_srcs "src/paddle/phi/*.cc") +set(ap_phi_deps axpr ap_code_module ap_code_gen ap_kernel_dispatch) +cc_library( + ap_phi + SRCS ${ap_phi_srcs} + DEPS ${ap_phi_deps}) + +file(GLOB_RECURSE ap_pir_srcs "src/paddle/pir/*.cc") +set(ap_pir_deps axpr ap_drr) +cc_library( + ap_pir + SRCS ${ap_pir_srcs} + DEPS ${ap_pir_deps}) + +file(GLOB_RECURSE ap_reified_drr_srcs "src/reified_drr/*.cc") +set(ap_reified_drr_deps axpr ap_drr ap_code_module ap_code_gen) +cc_library( + ap_reified_drr + SRCS ${ap_reified_drr_srcs} + DEPS ${ap_reified_drr_deps}) + +file(GLOB_RECURSE ap_pass_srcs "src/paddle/pass/*.cc") +set(ap_pass_deps + axpr + ap_pir + index_expr + ap_drr + ap_code_module + ap_code_gen + ap_reified_drr) +cc_library( + ap_pass + SRCS ${ap_pass_srcs} + DEPS ${ap_pass_deps}) diff --git a/paddle/ap/include/adt/adt.h b/paddle/ap/include/adt/adt.h new file mode 100644 index 00000000000000..70c0d029f9516d --- /dev/null +++ b/paddle/ap/include/adt/adt.h @@ -0,0 +1,566 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" +#include "paddle/common/overloaded.h" + +namespace ap::adt { + +template +struct Rc { + public: + Rc() : data_(std::make_shared()) {} + explicit Rc(const std::shared_ptr& data) : data_(data) {} + Rc(const Rc&) = default; + Rc(Rc&&) = default; + Rc& operator=(const Rc&) = default; + Rc& operator=(Rc&&) = default; + + template , Rc> && + !std::is_same_v, std::shared_ptr>, + bool> = true> + explicit Rc(Arg&& arg) : data_(new T{std::forward(arg)}) {} + + template + explicit Rc(Arg0&& arg0, Arg1&& arg1, Args&&... args) + : data_(new T{std::forward(arg0), + std::forward(arg1), + std::forward(args)...}) {} + + T* operator->() { return data_.get(); } + const T* operator->() const { return data_.get(); } + + T& operator*() { return *data_; } + const T& operator*() const { return *data_; } + + bool operator==(const Rc& other) const { + if (other.data_.get() == this->data_.get()) { + return true; + } + return *other.data_ == *this->data_; + } + + const std::shared_ptr& shared_ptr() const { return data_; } + + const void* __adt_rc_shared_ptr_raw_ptr() const { return data_.get(); } + + private: + std::shared_ptr data_; +}; + +#define ADT_DEFINE_RC(class_name, ...) \ + struct class_name : public ::ap::adt::Rc<__VA_ARGS__> { \ + using ::ap::adt::Rc<__VA_ARGS__>::Rc; \ + }; + +#define ADT_DEFINE_VARIANT_METHODS(...) 
\ + ADT_DEFINE_VARIANT_METHODS_WITHOUT_TRYGET(__VA_ARGS__) \ + template \ + static constexpr bool IsMyAlternative() { \ + if constexpr (start_idx >= std::variant_size_v<__VA_ARGS__>) { \ + return false; \ + } else { \ + using AlternativeT = \ + typename std::variant_alternative_t; \ + if constexpr (std::is_same_v) { \ + return true; \ + } else { \ + return IsMyAlternative<__AlternativeT, start_idx + 1>(); \ + } \ + } \ + } \ + template \ + ::ap::adt::Result<__ADT_T> TryGet() const { \ + ADT_CHECK(this->template Has<__ADT_T>()); \ + return this->template Get<__ADT_T>(); \ + } + +#define ADT_DEFINE_VARIANT_METHODS_WITHOUT_TRYGET(...) \ + DEFINE_MATCH_METHOD(); \ + const __VA_ARGS__& variant() const { \ + return reinterpret_cast(*this); \ + } \ + template \ + bool Has() const { \ + return std::holds_alternative<__ADT_T>(variant()); \ + } \ + template \ + const __ADT_T& Get() const { \ + return std::get<__ADT_T>(variant()); \ + } \ + bool operator!=(const __VA_ARGS__& other) const { \ + return !(*this == other); \ + } \ + bool operator==(const __VA_ARGS__& other) const { \ + return std::visit( \ + [](const auto& lhs, const auto& rhs) { \ + if constexpr (std::is_same_v, \ + std::decay_t>) { \ + return lhs == rhs; \ + } else { \ + return false; \ + } \ + }, \ + this->variant(), \ + other); \ + } + +template +class List final { + public: + List(const List&) = default; + List(List&&) = default; + List& operator=(const List&) = default; + List& operator=(List&&) = default; + + using value_type = T; + + explicit List() : vector_(std::make_shared>()) {} + + template < + typename Arg, + std::enable_if_t, List>, bool> = true> + explicit List(Arg&& arg) + : vector_(std::make_shared>( + std::vector{std::forward(arg)})) {} + + template + List(Arg0&& arg0, Arg1&& arg1, Args&&... 
args) + : vector_(std::make_shared>( + std::vector{std::forward(arg0), + std::forward(arg1), + std::forward(args)...})) {} + + bool operator==(const List& other) const { + if (&vector() == &other.vector()) { + return true; + } + return vector() == other.vector(); + } + + bool operator!=(const List& other) const { return !(*this == other); } + + std::vector& operator*() const { return *vector_; } + std::vector* operator->() const { return vector_.get(); } + + const std::vector& vector() const { return *vector_; } + + const auto& Get(std::size_t idx) const { return vector_->at(idx); } + + private: + std::shared_ptr> vector_; +}; + +#define ADT_DEFINE_TAG(TagName) \ + template \ + class TagName { \ + public: \ + TagName() = default; \ + TagName(const TagName&) = default; \ + TagName(TagName&&) = default; \ + TagName& operator=(const TagName&) = default; \ + TagName& operator=(TagName&&) = default; \ + \ + bool operator==(const TagName& other) const { \ + return value_ == other.value(); \ + } \ + \ + bool operator!=(const TagName& other) const { \ + return value_ != other.value(); \ + } \ + \ + template , TagName>, \ + bool> = true> \ + explicit TagName(Arg&& value) : value_(value) {} \ + \ + const T& value() const { return value_; } \ + \ + private: \ + T value_; \ + }; + +// Undefined = {} +struct Undefined final : public std::monostate { + using std::monostate::monostate; +}; + +// Ok = {} +struct Ok final : public std::monostate { + using std::monostate::monostate; +}; + +inline std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); +} + +struct Nothing : public std::monostate { + using std::monostate::monostate; +}; + +struct IdentityFunc : public std::monostate { + using std::monostate::monostate; +}; + +template +using EitherImpl = std::variant; + +template +struct Either : public EitherImpl { + using EitherImpl::EitherImpl; + ADT_DEFINE_VARIANT_METHODS_WITHOUT_TRYGET(EitherImpl); +}; + +template 
+struct Maybe : public Either { + using Either::Either; +}; + +namespace source_code { + +struct CodeLocation { + std::string file_name; + int line_no; + std::string func_name; + std::string code; + + bool operator==(const CodeLocation& other) const { + return this->file_name == other.file_name && + this->line_no == other.line_no && + this->func_name == other.func_name && this->code == other.code; + } +}; + +using CallStack = std::list; + +} // namespace source_code + +namespace errors { + +struct RuntimeError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const RuntimeError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "RuntimeError"; } +}; + +struct InvalidArgumentError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const InvalidArgumentError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "InvalidArgumentError"; } +}; + +struct AttributeError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const AttributeError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "AttributeError"; } +}; + +struct NameError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const NameError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "NameError"; } +}; + +struct ValueError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const ValueError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "ValueError"; } +}; + +struct ZeroDivisionError { + std::string msg; + source_code::CallStack 
call_stack{}; + + bool operator==(const ZeroDivisionError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "ZeroDivisionError"; } +}; + +struct TypeError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const TypeError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "TypeError"; } +}; + +struct IndexError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const IndexError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "IndexError"; } +}; + +struct KeyError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const KeyError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "KeyError"; } +}; + +struct MismatchError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const MismatchError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "MismatchError"; } +}; + +struct NotImplementedError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const NotImplementedError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "NotImplementedError"; } +}; + +struct SyntaxError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const SyntaxError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "SyntaxError"; } +}; + +struct ModuleNotFoundError { + std::string msg; + source_code::CallStack call_stack{}; + + 
bool operator==(const ModuleNotFoundError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "ModuleNotFoundError"; } +}; + +struct AssertionError { + std::string msg; + source_code::CallStack call_stack{}; + + bool operator==(const AssertionError& other) const { + return this->msg == other.msg && this->call_stack == other.call_stack; + } + + const char* class_name() const { return "AssertionError"; } +}; + +using ErrorBase = std::variant; + +struct [[nodiscard]] Error : public ErrorBase { + using ErrorBase::ErrorBase; + ADT_DEFINE_VARIANT_METHODS_WITHOUT_TRYGET(ErrorBase); + + const char* class_name() const { + return Match([](const auto& impl) { return impl.class_name(); }); + } + + const std::string& msg() const { + return Match( + [](const auto& impl) -> const std::string& { return impl.msg; }); + } + + const source_code::CallStack& call_stack() const { + return Match([](const auto& impl) -> const source_code::CallStack& { + return impl.call_stack; + }); + } + + std::string CallStackToString() const { + std::ostringstream ss; + for (const auto* code_location : call_stack()) { + ss << " File \"" << code_location->file_name << "\", line " + << code_location->line_no << ", in " << code_location->func_name + << "\n " << code_location->code << "\n"; + } + return ss.str(); + } + + Error operator<<(Error&& replacement) const { + if (this->call_stack().size() > 0) { + replacement.mut_call_stack()->push_front(*this->call_stack().begin()); + } + return std::move(replacement); + } + + Error operator<<(const Error& replacement) const { + if (this->call_stack().size() > 0) { + replacement.mut_call_stack()->push_front(*this->call_stack().begin()); + } + return replacement; + } + + Error operator<<( + const std::function& GetReplacement) const { + const auto& replacement = GetReplacement(*this); + return (*this) << replacement; + } + + Error operator<<(const source_code::CodeLocation* 
code_location) const { + mut_call_stack()->push_front(code_location); + return *this; + } + + private: + source_code::CallStack* mut_call_stack() const { + return const_cast(&call_stack()); + } +}; + +} // namespace errors + +template +struct [[nodiscard]] Result : public Either { + using Either::Either; + + bool HasError() const { return this->template Has(); } + + bool HasOkValue() const { return !HasError(); } + + const errors::Error& GetError() const { + return this->template Get(); + } + + const T& GetOkValue() const { return this->template Get(); } +}; + +struct Break : public std::monostate { + using std::monostate::monostate; +}; + +struct Continue : public std::monostate { + using std::monostate::monostate; +}; + +using LoopCtrlImpl = std::variant; + +struct LoopCtrl : public LoopCtrlImpl { + using LoopCtrlImpl::LoopCtrlImpl; + + ADT_DEFINE_VARIANT_METHODS(LoopCtrlImpl); +}; + +template +adt::Result> WeakPtrLock(const std::weak_ptr& weak_ptr) { + const auto& ptr = weak_ptr.lock(); + if (!ptr) { + return errors::RuntimeError{"weak_ptr.lock() failed."}; + } + return ptr; +} + +#define ADT_CURRENT_CODE_LOCATION(filename, line_no, func_name, code) \ + ([] { \ + static const ::ap::adt::source_code::CodeLocation loc{ \ + filename, line_no, func_name, code}; \ + return &loc; \ + }()) + +// clang-format off +#define ADT_CHECK(...) /* NOLINT */ \ + if (!(__VA_ARGS__)) /* NOLINT */ \ + return ::ap::adt::errors::Error{::ap::adt::errors::ValueError{ /* NOLINT */ \ + "Check '" #__VA_ARGS__ "' failed." /* NOLINT */ \ + }} << ADT_CURRENT_CODE_LOCATION( /* NOLINT */ \ + __FILE__, __LINE__, __FUNCTION__, #__VA_ARGS__ /* NOLINT */ \ + ) +// clang-format on + +#define ADT_RETURN_IF_ERR(...) \ + if (const auto& __result##__LINE__ = __VA_ARGS__; \ + __result##__LINE__.HasError()) \ + return __result##__LINE__.GetError() << ADT_CURRENT_CODE_LOCATION( \ + __FILE__, __LINE__, __FUNCTION__, #__VA_ARGS__) + +#define ADT_LET_CONST_REF(var, ...) 
\ + const auto& __result_##var = __VA_ARGS__; \ + const auto* __ptr_##var = \ + (__result_##var.HasError() ? nullptr : &__result_##var.GetOkValue()); \ + const auto& var = *__ptr_##var; \ + if (__result_##var.HasError()) \ + return __result_##var.GetError() << ADT_CURRENT_CODE_LOCATION( \ + __FILE__, __LINE__, __FUNCTION__, #__VA_ARGS__) + +} // namespace ap::adt diff --git a/paddle/ap/include/adt/bfs_walker.h b/paddle/ap/include/adt/bfs_walker.h new file mode 100644 index 00000000000000..17858f997fdffd --- /dev/null +++ b/paddle/ap/include/adt/bfs_walker.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::adt { + +// breadth-first search visitor +template +class BfsWalker final { + public: + BfsWalker(const BfsWalker&) = delete; + BfsWalker(BfsWalker&&) = delete; + + using NodeHandlerType = std::function(NodeType)>; + using NodesVisitorType = + std::function(NodeType, const NodeHandlerType&)>; + + BfsWalker(const NodesVisitorType& VisitNextNodesVal) + : VisitNextNodes(VisitNextNodesVal) {} + + adt::Result operator()(NodeType node, + const NodeHandlerType& NodeHandler) const { + std::array nodes{node}; + return (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + template + adt::Result operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandler) const { + std::queue node_queue; + std::unordered_set queued_nodes; + const auto& TryEnqueueNode = [&](NodeType node) -> adt::Result { + if (queued_nodes.count(node) == 0) { + node_queue.push(node); + queued_nodes.insert(node); + } + return adt::Ok{}; + }; + for (NodeIt iter = begin; iter != end; ++iter) { + ADT_RETURN_IF_ERR(TryEnqueueNode(*iter)); + } + while (!node_queue.empty()) { + NodeType node = node_queue.front(); + node_queue.pop(); + ADT_RETURN_IF_ERR(NodeHandler(node)); + ADT_RETURN_IF_ERR(VisitNextNodes(node, TryEnqueueNode)); + } + return adt::Ok{}; + } + + NodesVisitorType VisitNextNodes; +}; + +} // namespace ap::adt diff --git a/paddle/ap/include/adt/topo_walker.h b/paddle/ap/include/adt/topo_walker.h new file mode 100644 index 00000000000000..f04b79c3f62b67 --- /dev/null +++ b/paddle/ap/include/adt/topo_walker.h @@ -0,0 +1,90 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::adt { + +// Topological order visitor +template +class TopoWalker final { + public: + TopoWalker(const TopoWalker&) = default; + TopoWalker(TopoWalker&&) = default; + + using RetT = adt::Result; + + using NodeHandlerType = std::function; + using NodesVisitorType = + std::function; + + TopoWalker(const NodesVisitorType& VisitPrevNodesValue, + const NodesVisitorType& VisitNextNodesValue) + : VisitPrevNodes(VisitPrevNodesValue), + VisitNextNodes(VisitNextNodesValue) {} + + RetT operator()(NodeType node, const NodeHandlerType& NodeHandler) const { + std::array nodes{node}; + return (*this)(nodes.begin(), nodes.end(), NodeHandler); + } + + template + RetT operator()(NodeIt begin, + NodeIt end, + const NodeHandlerType& NodeHandler) const { + std::queue node_queue; + std::unordered_set queued_nodes; + const auto& TryEnqueueNode = [&](NodeType node) { + if (queued_nodes.count(node) == 0) { + node_queue.push(node); + queued_nodes.insert(node); + } + }; + for (NodeIt iter = begin; iter != end; ++iter) { + TryEnqueueNode(*iter); + } + while (!node_queue.empty()) { + NodeType node = node_queue.front(); + node_queue.pop(); + ADT_RETURN_IF_ERR(NodeHandler(node)); + ADT_RETURN_IF_ERR(VisitNextNodes(node, [&](NodeType node) -> RetT { + size_t num_unfinished_inputs = 0; + ADT_RETURN_IF_ERR(VisitPrevNodes(node, [&](NodeType in_node) -> RetT { + num_unfinished_inputs += (queued_nodes.count(in_node) > 0 ? 
0 : 1); + return adt::Ok{}; + })); + if (num_unfinished_inputs == 0) { + TryEnqueueNode(node); + } + return adt::Ok{}; + })); + } + return adt::Ok{}; + } + + TopoWalker GetReversed() const { + return TopoWalker(this->VisitNextNodes, this->VisitPrevNodes); + } + + NodesVisitorType VisitPrevNodes; + NodesVisitorType VisitNextNodes; +}; + +} // namespace ap::adt diff --git a/paddle/ap/include/axpr/abstract_list.h b/paddle/ap/include/axpr/abstract_list.h new file mode 100644 index 00000000000000..cf83cd43fd181d --- /dev/null +++ b/paddle/ap/include/axpr/abstract_list.h @@ -0,0 +1,119 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/list.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/mutable_list.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +template +using AbstractListImpl = std::variant, + adt::List, + axpr::MutableList>; + +template +struct AbstractList : public AbstractListImpl { + using AbstractListImpl::AbstractListImpl; + + ADT_DEFINE_VARIANT_METHODS(AbstractListImpl); + + static adt::Result> CastFrom(const ValueT& value) { + using RetT = adt::Result>; + return value.Match( + [&](const adt::List& impl) -> RetT { return impl; }, + [&](const adt::List& impl) -> RetT { return impl; }, + [&](const axpr::MutableList& impl) -> RetT { return impl; }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + + "only list, SerializableList, MutableList are convertible to " + "AbstractList. (" + + GetTypeName(value) + " given)"}; + }); + } + + static bool CastableFrom(const ValueT& value) { + using RetT = bool; + return value.Match( + [&](const adt::List& impl) -> RetT { return true; }, + [&](const adt::List& impl) -> RetT { return true; }, + [&](const axpr::MutableList& impl) -> RetT { return true; }, + [&](const auto&) -> RetT { return false; }); + } + + adt::Result size() const { + using RetT = adt::Result; + return Match( + [](const axpr::MutableList& impl) -> RetT { + ADT_LET_CONST_REF(data_vec, impl.Get()); + return data_vec->size(); + }, + [](const auto& impl) -> RetT { return impl->size(); }); + } + + adt::Result at(std::size_t i) const { + using RetT = adt::Result; + return Match( + [&](const adt::List& impl) -> RetT { return impl->at(i); }, + [&](const adt::List& impl) -> RetT { + return impl->at(i).template CastTo(); + }, + [&](const axpr::MutableList& impl) -> RetT { + ADT_LET_CONST_REF(data_vec, impl.Get()); + return data_vec->at(i); + }); + } + + template + adt::Result Visit(const DoEachT& 
DoEach) const { + using Ok = adt::Result; + return Match( + [&](const adt::List& impl) -> Ok { + for (const auto& elt : *impl) { + ADT_LET_CONST_REF(loop_ctrl, DoEach(elt)); + if (loop_ctrl.template Has()) { + break; + } + } + return adt::Ok{}; + }, + [&](const adt::List& impl) -> Ok { + for (const auto& serializable_elt : *impl) { + const auto& elt = serializable_elt.template CastTo(); + ADT_LET_CONST_REF(loop_ctrl, DoEach(elt)); + if (loop_ctrl.template Has()) { + break; + } + } + return adt::Ok{}; + }, + [&](const axpr::MutableList& impl) -> Ok { + ADT_LET_CONST_REF(vec, impl.Get()); + for (const auto& elt : *vec) { + ADT_LET_CONST_REF(loop_ctrl, DoEach(elt)); + if (loop_ctrl.template Has()) { + break; + } + } + return adt::Ok{}; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/adt.h b/paddle/ap/include/axpr/adt.h new file mode 100644 index 00000000000000..7131d5ecfc2a4b --- /dev/null +++ b/paddle/ap/include/axpr/adt.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::axpr { + +ADT_DEFINE_TAG(tVar); + +using adt::Nothing; +using adt::Result; + +template +using Maybe = adt::Maybe; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/anf_expr.h b/paddle/ap/include/axpr/anf_expr.h new file mode 100644 index 00000000000000..ffa14579126908 --- /dev/null +++ b/paddle/ap/include/axpr/anf_expr.h @@ -0,0 +1,94 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/atomic.h" + +namespace ap::axpr { + +template +struct IfImpl { + Atomic cond; + Expr true_expr; + Expr false_expr; + + bool operator==(const IfImpl& other) const { + return (this->cond == other.cond) && + (this->true_expr == other.false_expr) && + (this->false_expr == other.false_expr); + } +}; + +template +ADT_DEFINE_RC(If, const IfImpl); + +template +using CombinedBase = std::variant, If>; + +template +struct Combined : public CombinedBase { + using CombinedBase::CombinedBase; + ADT_DEFINE_VARIANT_METHODS(CombinedBase); +}; + +template +struct Bind { + tVar var; + Combined val; + + bool operator==(const Bind& other) const { + return this->var == other.var && this->val == other.val; + } +}; + +template +struct LetImpl { + std::vector> bindings; + Expr body; + + bool operator==(const LetImpl& other) const { + return this->bindings == other.bindings && this->body == other.body; + } +}; + +template +ADT_DEFINE_RC(Let, const LetImpl); + +struct AnfExpr; + +// expr := aexpr | cexpr | let [VAR cexpr] expr +// cexpr := (aexpr aexpr ...) 
| (If aexpr expr expr) +using AnfExprBase = + std::variant, Combined, Let>; + +// A-norm form +struct AnfExpr : public AnfExprBase { + using AnfExprBase::AnfExprBase; + ADT_DEFINE_VARIANT_METHODS(AnfExprBase); + + static constexpr const char* kString() { return "str"; } + static constexpr const char* kLambda() { return "lambda"; } + static constexpr const char* kIf() { return "if"; } + static constexpr const char* kLet() { return "__builtin_let__"; } + + std::string DumpToJsonString() const; + std::string DumpToJsonString(int indent) const; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/anf_expr_builder.h b/paddle/ap/include/axpr/anf_expr_builder.h new file mode 100644 index 00000000000000..f201b4fbd8fed3 --- /dev/null +++ b/paddle/ap/include/axpr/anf_expr_builder.h @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/anf_expr.h" +#include "paddle/ap/include/axpr/atomic_builder.h" + +namespace ap::axpr { + +class AnfExprBuilder : public AtomicExprBuilder { + public: + AnfExprBuilder() {} + AnfExprBuilder(const AnfExprBuilder&) = delete; + AnfExprBuilder(AnfExprBuilder&&) = delete; + + Combined Call(const Atomic& f, + const std::vector>& args) { + return Combined{ap::axpr::Call{f, args}}; + } + + Combined If(const Atomic& c, + const AnfExpr& t, + const AnfExpr& f) { + return Combined{ap::axpr::If{c, t, f}}; + } + + ap::axpr::Bind Bind(const std::string& var, + const Combined& val) { + return ap::axpr::Bind{tVar{var}, val}; + } + + ap::axpr::Let Let( + const std::vector>& assigns, + const AnfExpr& body) { + return ap::axpr::Let{assigns, body}; + } + + AnfExpr operator()(const Atomic& atomic) { return AnfExpr{atomic}; } + + AnfExpr operator()(const Combined& combined) { + return AnfExpr{combined}; + } + + AnfExpr operator()(const ap::axpr::Let& let) { return AnfExpr{let}; } + + private: +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/anf_expr_helper.h b/paddle/ap/include/axpr/anf_expr_helper.h new file mode 100644 index 00000000000000..19ea269c88a1e8 --- /dev/null +++ b/paddle/ap/include/axpr/anf_expr_helper.h @@ -0,0 +1,214 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/ap/include/axpr/anf_expr.h" + +namespace ap::axpr { + +struct AnfExprHelper { + adt::Result FunctionToString( + const Lambda& lambda) const { + return FunctionToString("unnamed", lambda); + } + + private: + adt::Result FunctionToString( + const std::string& func_name, const Lambda& lambda) const { + std::ostringstream ss; + auto Generate = [&](const std::string& str) { ss << str << "\n"; }; + ADT_RETURN_IF_ERR(SerializeFunction(Generate, func_name, lambda)); + return ss.str(); + } + + adt::Result SerializeFunction( + const std::function& Generate, + const std::string& func_name, + const Lambda& lambda) const { + { + std::ostringstream ss; + ss << "def " << func_name << "("; + int i = 0; + for (const auto& arg : lambda->args) { + if (i++ > 0) { + ss << ", "; + } + ss << arg.value(); + } + ss << "):"; + Generate(ss.str()); + } + { + auto BodyGenerate = [&](const std::string& str) { + Generate(std::string(" ") + str); + }; + ADT_RETURN_IF_ERR(SerializeLastExprInLambda(BodyGenerate, lambda->body)); + } + return adt::Ok{}; + } + + struct LambdaBodySerializeCtx { + std::function Generate; + + std::size_t auto_id_in_body = 0; + + std::size_t GetAutoIdInBody() { return auto_id_in_body++; } + }; + + adt::Result SerializeLastExprInLambda( + const std::function& Generate, + const AnfExpr& lambda_body) const { + LambdaBodySerializeCtx ctx{Generate}; + return SerializeLastExprInLambda(&ctx, lambda_body); + } + + adt::Result SerializeLastExprInLambda( + LambdaBodySerializeCtx* ctx, const AnfExpr& lambda_body) const { + return lambda_body.Match([&](const auto& impl) -> adt::Result { + return SerializeLastExprInLambdaImpl(ctx, impl); + }); + } + + adt::Result SerializeLastExprInLambdaImpl( + LambdaBodySerializeCtx* ctx, const Atomic& atomic) const { + ADT_LET_CONST_REF( + atomic_str, + atomic.Match([&](const auto& impl) -> adt::Result { + return AtomicToStringImpl(ctx, impl); + })); + ctx->Generate(std::string() + "return " + 
atomic_str); + return adt::Ok{}; + } + + adt::Result AtomicToString(LambdaBodySerializeCtx* ctx, + const Atomic& atomic) const { + return atomic.Match([&](const auto& impl) -> adt::Result { + return AtomicToStringImpl(ctx, impl); + }); + } + + adt::Result SerializeLastExprInLambdaImpl( + LambdaBodySerializeCtx* ctx, const Combined& combined) const { + ADT_LET_CONST_REF( + combined_str, + combined.Match([&](const auto& impl) -> adt::Result { + return CombinedToStringImpl(ctx, impl); + })); + ctx->Generate(std::string() + "return " + combined_str); + return adt::Ok{}; + } + + adt::Result CombinedToString( + LambdaBodySerializeCtx* ctx, const Combined& combined) const { + return combined.Match([&](const auto& impl) -> adt::Result { + return CombinedToStringImpl(ctx, impl); + }); + } + + adt::Result SerializeLastExprInLambdaImpl( + LambdaBodySerializeCtx* ctx, const Let& let) const { + for (const auto& [var, combined] : let->bindings) { + ADT_LET_CONST_REF(combined_str, CombinedToString(ctx, combined)); + ctx->Generate(var.value() + " = " + combined_str); + } + return SerializeLastExprInLambda(ctx, let->body); + } + + adt::Result CombinedToStringImpl( + LambdaBodySerializeCtx* ctx, const Call& call) const { + std::ostringstream ss; + ADT_LET_CONST_REF(func_name, AtomicToString(ctx, call->func)); + ss << func_name << "("; + int i = 0; + for (const auto& arg : call->args) { + if (i++ > 0) { + ss << ", "; + } + ADT_LET_CONST_REF(arg_str, AtomicToString(ctx, arg)); + ss << arg_str; + } + ss << ")"; + return ss.str(); + } + + adt::Result CombinedToStringImpl( + LambdaBodySerializeCtx* ctx, const If& if_expr) const { + std::ostringstream ss; + ADT_LET_CONST_REF(cond, AtomicToString(ctx, if_expr->cond)); + ADT_LET_CONST_REF(true_expr, AnfExprToString(ctx, if_expr->true_expr)); + ADT_LET_CONST_REF(false_expr, AnfExprToString(ctx, if_expr->false_expr)); + ss << true_expr << " if " << cond << " else " << false_expr; + return ss.str(); + } + + adt::Result 
AnfExprToString(LambdaBodySerializeCtx* ctx, + const AnfExpr& anf_expr) const { + return anf_expr.Match( + [&](const Atomic& atomic) -> adt::Result { + return AtomicToString(ctx, atomic); + }, + [&](const Combined& combined) -> adt::Result { + return CombinedToString(ctx, combined); + }, + [&](const Let&) -> adt::Result { + return adt::errors::TypeError{ + "Let is not supported in AnfExprToString()."}; + }); + } + + adt::Result AtomicToStringImpl(LambdaBodySerializeCtx* ctx, + const adt::Nothing&) const { + return std::string("None"); + } + + adt::Result AtomicToStringImpl(LambdaBodySerializeCtx* ctx, + bool c) const { + return std::string(c ? "True" : "False"); + } + + adt::Result AtomicToStringImpl(LambdaBodySerializeCtx* ctx, + int64_t c) const { + return std::to_string(c); + } + + adt::Result AtomicToStringImpl(LambdaBodySerializeCtx* ctx, + double c) const { + return std::to_string(c); + } + + adt::Result AtomicToStringImpl(LambdaBodySerializeCtx* ctx, + const std::string& str) const { + std::ostringstream ss; + ss << std::quoted(str); + return ss.str(); + } + + adt::Result AtomicToStringImpl( + LambdaBodySerializeCtx* ctx, const tVar& var) const { + return var.value(); + } + + adt::Result AtomicToStringImpl( + LambdaBodySerializeCtx* ctx, const Lambda& lambda) const { + const auto& func_name = + std::string("tmp_func_") + std::to_string(ctx->GetAutoIdInBody()); + ADT_RETURN_IF_ERR(SerializeFunction(ctx->Generate, func_name, lambda)); + return func_name; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/anf_expr_util.h b/paddle/ap/include/axpr/anf_expr_util.h new file mode 100644 index 00000000000000..5193d9af2c31c9 --- /dev/null +++ b/paddle/ap/include/axpr/anf_expr_util.h @@ -0,0 +1,760 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "glog/logging.h" +#include "nlohmann/json.hpp" +#include "paddle/ap/include/axpr/anf_expr.h" +#include "paddle/ap/include/axpr/anf_expr_builder.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/core_expr_builder.h" +#include "paddle/common/enforce.h" + +namespace ap::axpr { + +AnfExpr ConvertCoreExprToAnfExpr(const CoreExpr& core_expr); + +namespace detail { + +struct CoreExprToAnfExprConverter { + AnfExpr ConvertCoreExprToAnfExpr(const CoreExpr& core_expr) { + return core_expr.Match( + [&](const Atomic& atomic) -> AnfExpr { + return ConvertAtomic(atomic); + }, + [&](const ComposedCall>& composed_call) -> AnfExpr { + return ConvertComposedCall(composed_call); + }); + } + + private: + Atomic ConvertAtomic(const Atomic& atomic) { + return atomic.Match( + [&](const Lambda& lambda) -> Atomic { + return ConvertLambda(lambda); + }, + [&](const Symbol& symbol) -> Atomic { + return symbol.Match( + [&](const tVar& var) -> Atomic { + return Atomic{var}; + }, + [&](const builtin_symbol::Symbol& symbol) -> Atomic { + tVar var{symbol.Name()}; + return Atomic{var}; + }); + }, + [&](adt::Nothing) -> Atomic { + return Atomic{adt::Nothing{}}; + }, + [&](bool c) -> Atomic { return Atomic{c}; }, + [&](int64_t c) -> Atomic { return Atomic{c}; }, + [&](double c) -> Atomic { return Atomic{c}; }, + [&](const std::string& val) -> Atomic { + return Atomic{val}; + }); + } + + Atomic ConvertLambda(const Lambda& lambda) { + return Lambda{lambda->args, + ConvertCoreExprToAnfExpr(lambda->body)}; + } + + 
AnfExpr ConvertComposedCall( + const ComposedCall>& composed_call) { + const auto& outer_func = composed_call->outer_func; + return outer_func.Match( + [&](const Lambda& lambda) -> AnfExpr { + std::vector> bindings; + return ConvertComposedCallToLet(composed_call, &bindings); + }, + [&](const Symbol& symbol) -> AnfExpr { + return symbol.Match( + [&](const tVar& var) -> AnfExpr { + CHECK_EQ(var.value(), kBuiltinReturn()); + return ConvertComposedCallToCombined(composed_call); + }, + [&](const builtin_symbol::Symbol& symbol) -> AnfExpr { + LOG(FATAL) << "outer_func should be a lambda or " + << kBuiltinReturn(); + return Atomic{""}; + }); + }, + [&](const auto& c) -> AnfExpr { + LOG(FATAL) << "outer_func should be a lambda or " << kBuiltinReturn(); + return Atomic(c); + }); + } + + Combined ConvertComposedCallToCombined( + const ComposedCall>& composed_call) { + const auto& f = ConvertAtomic(composed_call->inner_func); + std::vector> args; + args.reserve(composed_call->args.size()); + for (const auto& arg : composed_call->args) { + args.push_back(ConvertAtomic(arg)); + } + return Combined{Call{f, std::move(args)}}; + } + + AnfExpr ConvertComposedCallToLet( + const ComposedCall>& composed_call, + std::vector>* bindings) { + const auto& outer_func = composed_call->outer_func; + return outer_func.Match( + [&](const Lambda& lambda) -> AnfExpr { + CHECK_EQ(lambda->args.size(), 1); + const auto& val = ConvertComposedCallToCombined(composed_call); + Bind binding{lambda->args.at(0), val}; + bindings->emplace_back(std::move(binding)); + const auto& body = lambda->body; + return body.Match( + [&](const Atomic& atomic_body) -> AnfExpr { + return Let{*bindings, ConvertAtomic(atomic_body)}; + }, + [&](const ComposedCall>& composed_call_body) + -> AnfExpr { + return ConvertComposedCallToLet(composed_call_body, bindings); + }); + }, + [&](const Symbol& symbol) -> AnfExpr { + return symbol.Match( + [&](const tVar& var) -> AnfExpr { + CHECK_EQ(var.value(), kBuiltinReturn()); + const 
auto& body = ConvertComposedCallToCombined(composed_call); + return Let{*bindings, body}; + }, + [&](const builtin_symbol::Symbol& symbol) -> AnfExpr { + LOG(FATAL) << "outer_func should be a lambda or " + << kBuiltinReturn(); + return Atomic{""}; + }); + }, + [&](const auto& c) -> AnfExpr { + LOG(FATAL) << "outer_func should be a lambda or " << kBuiltinReturn(); + return Atomic(c); + }); + } +}; + +} // namespace detail + +inline AnfExpr ConvertCoreExprToAnfExpr(const CoreExpr& core_expr) { + return detail::CoreExprToAnfExprConverter().ConvertCoreExprToAnfExpr( + core_expr); +} + +namespace detail { + +// Convert anf expr to core expr without duplicate var name. +struct AnfExprToCoreExprConverter { + AnfExprToCoreExprConverter() : core_() {} + + using LazyCoreExpr = std::function( + const Atomic& continuation)>; + + using MaybeLazyCoreExprBase = std::variant; + + struct MaybeLazyCoreExpr : public MaybeLazyCoreExprBase { + using MaybeLazyCoreExprBase::MaybeLazyCoreExprBase; + + DEFINE_MATCH_METHOD(); + + const MaybeLazyCoreExprBase& variant() const { + return reinterpret_cast(*this); + } + + template + bool Has() const { + return std::holds_alternative(variant()); + } + + template + const T& Get() const { + return std::get(variant()); + } + }; + + template + MaybeLazyCoreExpr CoreVal(const T& val) { + return MaybeLazyCoreExpr{CoreExpr{val}}; + } + + MaybeLazyCoreExpr LazyCoreVal(const LazyCoreExpr& lazy) { + return MaybeLazyCoreExpr{lazy}; + } + + using value_type = MaybeLazyCoreExpr; + + CoreExpr ConvertAnfExprToCoreExpr(const AnfExpr& anf_expr) { + MaybeLazyCoreExpr ret_val = Convert(anf_expr); + const auto& lazy_core_expr = TryWrapperToLazyCoreExpr(ret_val); + CoreExpr ret = lazy_core_expr(CoreExprBuilder().Var(kBuiltinReturn())); + return ret.Match( + [&](const Atomic&) -> CoreExpr { return ret; }, + [&](const ComposedCallAtomic& composed_call) -> CoreExpr { + Atomic return_id{tVar{kBuiltinReturn()}}; + Atomic identity{Symbol{builtin_symbol::Id{}}}; + if 
(composed_call->outer_func != return_id) { + return composed_call; + } + if (composed_call->inner_func != identity) { + return composed_call; + } + if (composed_call->args.size() != 1) { + return composed_call; + } + return composed_call->args.at(0); + }); + } + + value_type Convert(const AnfExpr& anf_expr) { + return anf_expr.Match( + [&](const Atomic& atomic_expr) { + return ConvertAtomic(atomic_expr); + }, + [&](const Combined& combined_expr) { + return ConvertCombined(combined_expr); + }, + [&](const Let& let_expr) { return ConvertLet(let_expr); }); + } + + LazyCoreExpr TryWrapperToLazyCoreExpr( + const MaybeLazyCoreExpr& maybe_lazy_core_expr) { + return maybe_lazy_core_expr.Match( + [&](const LazyCoreExpr& lazy) { return lazy; }, + [&](const CoreExpr& core_expr) { + PADDLE_ENFORCE_EQ( + core_expr.Has>(), + true, + phi::errors::InvalidArgument( + "core_expr should return a Atomic instance")); + const Atomic val = core_expr.Get>(); + return LazyCoreExpr([val](const Atomic& continuation) { + CoreExprBuilder core{}; + return core.ComposedCallAtomic( + continuation, Symbol{builtin_symbol::Id{}}, {val}); + }); + }); + } + + value_type ConvertAtomic(const Atomic& atomic_expr) { + return atomic_expr.Match( + [&](const tVar& var) { return ConvertVar(var); }, + [&](const adt::Nothing) { return ConvertNothing(); }, + [&](bool c) { return ConvertBool(c); }, + [&](int64_t c) { return ConvertInt64(c); }, + [&](double c) { return ConvertDouble(c); }, + [&](const std::string& c) { return ConvertString(c); }, + [&](const Lambda& lambda) { return ConvertLambda(lambda); }); + } + + value_type ConvertCombined(const Combined& combined_expr) { + return combined_expr.Match( + [&](const Call& call_expr) { return ConvertCall(call_expr); }, + [&](const If& if_expr) { return ConvertIf(if_expr); }); + } + + value_type ConvertVar(const tVar& var) { + const auto& opt_symbol = builtin_symbol::GetSymbolFromString(var.value()); + return CoreVal(opt_symbol.Match( + [&](const 
builtin_symbol::Symbol& symbol) -> Symbol { return symbol; }, + [&](const adt::Nothing&) -> Symbol { return var; })); + } + value_type ConvertNothing() { return CoreVal(core_.None()); } + value_type ConvertBool(const bool c) { return CoreVal(core_.Bool(c)); } + value_type ConvertInt64(const int64_t c) { return CoreVal(core_.Int64(c)); } + value_type ConvertDouble(const double c) { return CoreVal(core_.Double(c)); } + value_type ConvertString(const std::string& c) { + return CoreVal(core_.String(c)); + } + value_type ConvertLambda(const Lambda& anf_expr) { + const auto& core_body_val = Convert(anf_expr->body); + LazyCoreExpr lazy_core_expr = TryWrapperToLazyCoreExpr(core_body_val); + CoreExpr core_body = lazy_core_expr(core_.Var(kBuiltinReturn())); + return CoreVal(core_.Lambda(anf_expr->args, core_body)); + } + + value_type ConvertCall(const Call& anf_expr) { + const auto& inner_func = ConvertAtomicToAtomic(anf_expr->func); + std::vector> core_args{}; + core_args.reserve(anf_expr->args.size()); + for (const auto& arg : anf_expr->args) { + core_args.push_back(ConvertAtomicToAtomic(arg)); + } + return LazyCoreVal( + [inner_func, core_args](const Atomic& continuation) { + CoreExprBuilder core{}; + return core.ComposedCallAtomic(continuation, inner_func, core_args); + }); + } + value_type ConvertIf(const If& anf_expr) { + const Atomic& core_cond = ConvertAtomicToAtomic(anf_expr->cond); + const auto& MakeZeroArgLambda = [](const auto& expr_ptr) { + return AnfExprBuilder().Lambda({}, expr_ptr); + }; + const Atomic& core_true_expr = + ConvertAtomicToAtomic(MakeZeroArgLambda(anf_expr->true_expr)); + const Atomic& core_false_expr = + ConvertAtomicToAtomic(MakeZeroArgLambda(anf_expr->false_expr)); + return LazyCoreVal([=](const Atomic& continuation) { + CoreExprBuilder core{}; + return core.ComposedCallAtomic( + continuation, + core.Var("if"), + {core_cond, core_true_expr, core_false_expr}); + }); + } + value_type ConvertLet(const Let& anf_expr) { + std::vector symbol_names; 
+ std::vector lazy_core_exprs; + lazy_core_exprs.reserve(anf_expr->bindings.size()); + for (const auto& binding : anf_expr->bindings) { + symbol_names.push_back(binding.var.value()); + lazy_core_exprs.push_back(ConvertCombinedToLazyCoreExpr(binding.val)); + } + value_type body_val = Convert(anf_expr->body); + LazyCoreExpr body_lazy_core_expr = TryWrapperToLazyCoreExpr(body_val); + lazy_core_exprs.push_back(body_lazy_core_expr); + PADDLE_ENFORCE_EQ( + lazy_core_exprs.size(), + symbol_names.size() + 1, + phi::errors::InvalidArgument( + "lazy_core_exprs.size() should equal to symbol_names.size() + 1")); + return LazyCoreVal( + [symbol_names, lazy_core_exprs](Atomic continuation) { + CoreExprBuilder core{}; + LazyCoreExpr first_body_lazy_core_expr = lazy_core_exprs.at(0); + for (int i = lazy_core_exprs.size() - 1; i > 0; i--) { + const auto& var = symbol_names.at(i - 1); + LazyCoreExpr lazy_core_expr = lazy_core_exprs.at(i); + CoreExpr body = lazy_core_expr(continuation); + continuation = core.Lambda({tVar{var}}, body); + } + return first_body_lazy_core_expr(continuation); + }); + } + + private: + void CheckIsAtomic(const value_type& maybe_lazy_core_expr) { + PADDLE_ENFORCE_EQ(maybe_lazy_core_expr.Has(), + true, + phi::errors::InvalidArgument( + "ConvertAtomic should return a CoreExpr instance")); + const auto& core_expr = maybe_lazy_core_expr.Get(); + PADDLE_ENFORCE_EQ( + core_expr.Has>(), + true, + phi::errors::InvalidArgument( + "ConvertAtomic should return a Atomic instance")); + } + + Atomic GetAtomic(const value_type& val) { + return val.Get().Get>(); + } + + Atomic ConvertAtomicToAtomic(const Atomic& atomic_anf) { + value_type val = ConvertAtomic(atomic_anf); + CheckIsAtomic(val); + return GetAtomic(val); + } + + void CheckIsLazyCoreExpr(const value_type& maybe_lazy_core_expr) { + PADDLE_ENFORCE_EQ( + maybe_lazy_core_expr.Has(), + true, + phi::errors::InvalidArgument( + "ConvertCombined should return a LazyCoreExpr instance")); + } + + LazyCoreExpr 
GetLazyCoreExpr(const value_type& val) { + return val.Get(); + } + + LazyCoreExpr ConvertCombinedToLazyCoreExpr( + const Combined& combined_anf) { + value_type val = ConvertCombined(combined_anf); + CheckIsLazyCoreExpr(val); + return GetLazyCoreExpr(val); + } + + CoreExprBuilder core_; +}; + +} // namespace detail + +inline CoreExpr ConvertAnfExprToCoreExpr(const AnfExpr& anf_expr) { + return detail::AnfExprToCoreExprConverter().ConvertAnfExprToCoreExpr( + anf_expr); +} + +namespace detail { + +using adt::Result; + +using Json = nlohmann::json; + +inline adt::errors::Error JsonParseFailed(const Json& j_obj, + const std::string& msg) { + return adt::errors::TypeError{msg + " json: " + j_obj.dump()}; +} + +inline adt::errors::Error JsonParseMismatch(const Json& j_obj, + const std::string& msg) { + return adt::errors::MismatchError{msg}; +} + +typedef Result (*JsonParseFuncType)(const Json& j_obj); + +Result ConvertJsonToAnfExpr(const Json& j_obj); + +struct ParseJsonToAnfExprHelperVar { + static Result Call(const Json& j_obj) { + if (!j_obj.is_string()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: json " + "objects should be strings"); + } + std::string str = j_obj.get(); + return AnfExpr{AnfExprBuilder().Var(str)}; + } +}; + +struct ParseJsonToAnfExprHelperNull { + static Result Call(const Json& j_obj) { + if (!j_obj.is_null()) { + return JsonParseMismatch( + j_obj, "ParseJsonToAnfExpr: json object should be null."); + } + return AnfExpr{AnfExprBuilder().None()}; + } +}; + +struct ParseJsonToAnfExprHelperBool { + static Result Call(const Json& j_obj) { + if (!j_obj.is_boolean()) { + return JsonParseMismatch( + j_obj, "ParseJsonToAnfExpr: json object should be a boolean."); + } + bool c = j_obj.get(); + return AnfExpr{AnfExprBuilder().Bool(c)}; + } +}; + +struct ParseJsonToAnfExprHelperInt64 { + static Result Call(const Json& j_obj) { + if (!j_obj.is_number_integer()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr: json object " + "should be 
a intergral number."); + } + auto c = j_obj.get(); + return AnfExpr{AnfExprBuilder().Int64(c)}; + } +}; + +struct ParseJsonToAnfExprHelperDouble { + static Result Call(const Json& j_obj) { + if (!j_obj.is_number_float()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr: json object should " + "be a floating point number."); + } + auto c = j_obj.template get(); + return AnfExpr{AnfExprBuilder().Double(c)}; + } +}; + +struct ParseJsonToAnfExprHelperString { + static Result Call(const Json& j_obj) { + if (!j_obj.is_object()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr: an string " + "AnfExpr should be a json object."); + } + if (!j_obj.contains(AnfExpr::kString())) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr: an string " + "AnfExpr should contain a string."); + } + if (j_obj.size() != 1) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr: length of json " + "object should equal to 1."); + } + if (!j_obj[AnfExpr::kString()].is_string()) { + return JsonParseFailed( + j_obj, + "ParseJsonToAnfExpr: an string AnfExpr " + "should contain a string."); + } + auto c = j_obj[AnfExpr::kString()].get(); + return AnfExpr{AnfExprBuilder().String(c)}; + } +}; + +struct ParseJsonToAnfExprHelperLambdaAnfExpr { + static Result Call(const Json& j_obj) { + if (!j_obj.is_array()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: json " + "objects should be arrays."); + } + if (j_obj.size() != 3) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: length of " + "json array should equal to 3."); + } + if (j_obj.at(0) != AnfExpr::kLambda()) { + return JsonParseMismatch( + j_obj, + "ParseJsonToAnfExpr>: the first " + "element of json array should equal to 'lambda'."); + } + if (!j_obj.at(1).is_array()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the second " + "element of json array should be a list."); + } + std::vector> args; + for (const auto& arg : j_obj.at(1)) { + if (!arg.is_string()) { + return 
JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: lambda " + "arguments should be var names."); + } + args.emplace_back(arg.get()); + } + const auto& body = ConvertJsonToAnfExpr(j_obj.at(2)); + if (!body.HasOkValue()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the lambda " + "body should be a valid AnfExpr."); + } + return AnfExpr{AnfExprBuilder().Lambda(args, body.GetOkValue())}; + } +}; + +struct ParseJsonToAnfExprHelperCallAnfExpr { + static Result Call(const Json& j_obj) { + if (!j_obj.is_array()) { + return JsonParseMismatch( + j_obj, + "ParseJsonToAnfExpr>: json objects should be arrays."); + } + if (j_obj.empty()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: json arrays " + "should be not empty."); + } + const auto& func = ConvertJsonToAnfExpr(j_obj.at(0)); + if (!func.HasOkValue()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the function " + "should a valid AnfExpr."); + } + if (!func.GetOkValue().Has>()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the function " + "should a valid atomic AnfExpr."); + } + std::vector> args; + for (int i = 1; i < j_obj.size(); ++i) { + const auto& arg = j_obj.at(i); + const auto& arg_expr = ConvertJsonToAnfExpr(arg); + if (!arg_expr.HasOkValue()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the args " + "should be valid AnfExprs."); + } + if (!arg_expr.GetOkValue().Has>()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the args " + "should be valid atomic AnfExprs."); + } + args.push_back(arg_expr.GetOkValue().Get>()); + } + return AnfExpr{ + AnfExprBuilder().Call(func.GetOkValue().Get>(), args)}; + } +}; + +struct ParseJsonToAnfExprHelperIfAnfExpr { + static Result Call(const Json& j_obj) { + if (!j_obj.is_array()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: json objects " + "should be valid atomic AnfExprs."); + } + if (j_obj.size() != 4) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: the length of " + "json array 
should equal to 4."); + } + if (j_obj.at(0) != AnfExpr::kIf()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: the first " + "argument of json array should equal to 'if'."); + } + const auto& cond = ConvertJsonToAnfExpr(j_obj.at(1)); + if (!cond.HasOkValue()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the second " + "argument of json array should a valid AnfExpr."); + } + if (!cond.GetOkValue().Has>()) { + return JsonParseFailed( + j_obj, + "ParseJsonToAnfExpr>: the second argument of json array " + "should a valid atomic AnfExpr."); + } + const auto& cond_expr = cond.GetOkValue().Get>(); + const auto& true_expr = ConvertJsonToAnfExpr(j_obj.at(2)); + if (!true_expr.HasOkValue()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the third " + "argument of json array should a valid AnfExpr."); + } + const auto& false_expr = ConvertJsonToAnfExpr(j_obj.at(3)); + if (!false_expr.HasOkValue()) { + return JsonParseFailed(j_obj, + "ParseJsonToAnfExpr>: the forth " + "argument of json array should a valid AnfExpr."); + } + return AnfExpr{AnfExprBuilder().If( + cond_expr, true_expr.GetOkValue(), false_expr.GetOkValue())}; + } +}; + +struct ParseJsonToAnfExprHelperLetAnfExpr { + static Result Call(const Json& j_obj) { + if (!j_obj.is_array()) { + return JsonParseMismatch( + j_obj, + "ParseJsonToAnfExpr>: json objects should be arrays."); + } + if (j_obj.size() != 3) { + return JsonParseMismatch( + j_obj, + "ParseJsonToAnfExpr>: the length of " + "json array should equal to 3."); + } + if (j_obj.at(0) != AnfExpr::kLet()) { + return JsonParseMismatch(j_obj, + "ParseJsonToAnfExpr>: the first " + "argument of json array should be 'let'."); + } + std::vector> bindings; + const auto& j_bindings = j_obj.at(1); + for (int i = 0; i < j_bindings.size(); ++i) { + const auto& binding = j_bindings.at(i); + if (!binding.is_array()) { + return JsonParseFailed(binding, + "ParseJsonToAnfExpr>: bindings " + "should be json arrays."); + } + if (binding.size() 
!= 2) { + return JsonParseFailed(binding, + "ParseJsonToAnfExpr>: the size of " + "one binding should equal to 2."); + } + if (!binding.at(0).is_string()) { + return JsonParseFailed(binding.at(0), + "ParseJsonToAnfExpr>: the first " + "element of a binding should be var name."); + } + std::string var = binding.at(0).get(); + const auto& val = ConvertJsonToAnfExpr(binding.at(1)); + if (!val.HasOkValue()) { + return JsonParseFailed( + binding.at(1), + "ParseJsonToAnfExpr>: the second " + "element of a binding should be a valid AnfExpr."); + } + if (!val.GetOkValue().Has>()) { + return JsonParseFailed( + binding.at(1), + "ParseJsonToAnfExpr>: the second element of a binding " + "should be a valid combined AnfExpr."); + } + bindings.push_back(AnfExprBuilder().Bind( + var, val.GetOkValue().Get>())); + } + const auto& body = ConvertJsonToAnfExpr(j_obj.at(2)); + if (!body.HasOkValue()) { + return JsonParseFailed( + j_obj.at(2), + "ParseJsonToAnfExpr>: the body of Let " + "AnfExpr should be a valid AnfExpr."); + } + return AnfExpr{AnfExprBuilder().Let(bindings, body.GetOkValue())}; + } +}; + +inline const std::vector& GetJsonParseFuncs() { + static const std::vector vec{ + &ParseJsonToAnfExprHelperLambdaAnfExpr::Call, + &ParseJsonToAnfExprHelperIfAnfExpr::Call, + &ParseJsonToAnfExprHelperLetAnfExpr::Call, + &ParseJsonToAnfExprHelperCallAnfExpr::Call, + &ParseJsonToAnfExprHelperVar::Call, + &ParseJsonToAnfExprHelperNull::Call, + &ParseJsonToAnfExprHelperBool::Call, + &ParseJsonToAnfExprHelperInt64::Call, + &ParseJsonToAnfExprHelperDouble::Call, + &ParseJsonToAnfExprHelperString::Call, + }; + return vec; +} + +inline Result ConvertJsonToAnfExpr(const Json& j_obj) { + try { + for (const auto& parse_func : GetJsonParseFuncs()) { + const auto& ret = parse_func(j_obj); + if (ret.HasOkValue()) { + return ret.GetOkValue(); + } + if (!ret.GetError().Has()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" + << 
ret.GetError().class_name() + << ": ConvertJsonToAnfExpr: " << ret.GetError().msg(); + return ret.GetError(); + } + } + } catch (std::exception& e) { + return JsonParseFailed(j_obj, + "ConvertJsonToAnfExpr: throw error when parsing."); + } + return JsonParseFailed(j_obj, "ConvertJsonToAnfExpr: failed to convert."); +} + +inline Result MakeAnfExprFromJsonString(const std::string& json_str) { + try { + return detail::ConvertJsonToAnfExpr(Json::parse(json_str)); + } catch (std::exception& e) { + return adt::errors::InvalidArgumentError{ + std::string() + "json parse failed. exception::what():" + e.what()}; + } +} + +} // namespace detail + +inline adt::Result MakeAnfExprFromJsonString( + const std::string& json_str) { + return detail::MakeAnfExprFromJsonString(json_str); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/atomic.h b/paddle/ap/include/axpr/atomic.h new file mode 100644 index 00000000000000..0ed828bc7372c1 --- /dev/null +++ b/paddle/ap/include/axpr/atomic.h @@ -0,0 +1,74 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/common/overloaded.h" + +namespace ap::axpr { + +template +struct LambdaImpl { + std::vector> args; + Expr body; + + bool operator==(const LambdaImpl& other) const { + return (this->args == other.args) && (this->body == other.body); + } +}; + +template +ADT_DEFINE_RC(Lambda, const LambdaImpl); + +// aexpr := Var | CONST | (lambda [VAR] expr) + +template +struct ExprSymbolTrait { + using symbol_type = tVar; +}; + +template +using AtomicBase = std::variant::symbol_type, + adt::Nothing, + bool, + int64_t, + double, + std::string, + Lambda>; + +template +struct Atomic : public AtomicBase { + using AtomicBase::AtomicBase; + ADT_DEFINE_VARIANT_METHODS(AtomicBase); +}; + +template +struct CallImpl { + Atomic func; + std::vector> args; + + bool operator==(const CallImpl& other) const { + return (this->func == other.func) && (this->args == other.args); + } +}; + +template +ADT_DEFINE_RC(Call, const CallImpl); + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/atomic_builder.h b/paddle/ap/include/axpr/atomic_builder.h new file mode 100644 index 00000000000000..f75a7e1838fdef --- /dev/null +++ b/paddle/ap/include/axpr/atomic_builder.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/atomic.h" + +namespace ap::axpr { + +template +class AtomicExprBuilder { + public: + AtomicExprBuilder() {} + AtomicExprBuilder(const AtomicExprBuilder&) = delete; + AtomicExprBuilder(AtomicExprBuilder&&) = delete; + + Atomic Var(const std::string& name) { + return Atomic{tVar{name}}; + } + + Atomic Bool(bool c) { return Atomic{c}; } + + Atomic Int64(int64_t c) { return Atomic{c}; } + + Atomic Double(double c) { return Atomic{c}; } + + Atomic None() { return Atomic{adt::Nothing{}}; } + + Atomic String(const std::string& str) { return Atomic{str}; } + + Atomic Lambda(const std::vector>& args, + const Expr& body) { + return Atomic{ap::axpr::Lambda{args, body}}; + } + + private: +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/attr_map.h b/paddle/ap/include/axpr/attr_map.h new file mode 100644 index 00000000000000..9eaa103d73b008 --- /dev/null +++ b/paddle/ap/include/axpr/attr_map.h @@ -0,0 +1,95 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct AttrMapImpl { + std::unordered_map storage; + + size_t size() const { return storage.size(); } + + void clear() { storage.clear(); } + + Result Get(const std::string& var) const { + const auto& iter = storage.find(var); + if (iter == storage.end()) { + return AttributeError{"object has no attribute '" + var + "'"}; + } + return iter->second; + } + + bool Has(const std::string& var) const { + return storage.find(var) != storage.end(); + } + + template + Result Get(const std::string& var) const { + ADT_LET_CONST_REF(val, this->Get(var)); + ADT_CHECK(val.template Has()); + return val.template Get(); + } + + template + Result TryGet(const std::string& var) const { + return this->template Get(var); + } + + template + Result> OptGet(const std::string& var) const { + if (!this->Has(var)) { + return std::nullopt; + } + ADT_LET_CONST_REF(val, this->Get(var)); + ADT_CHECK(val.template Has()); + return val.template Get(); + } + + std::optional OptGet(const std::string& var) const { + const auto& iter = storage.find(var); + if (iter == storage.end()) { + return std::nullopt; + } + return iter->second; + } + + void Set(const std::string& var, const ValueT& val) { + this->storage[var] = val; + } + + bool Emplace(const std::string& var, const ValueT& val) { + return this->storage.emplace(var, val).second; + } + + bool operator==(const AttrMapImpl& other) const { return &other == this; } +}; + +template +ADT_DEFINE_RC(AttrMap, AttrMapImpl); + +template +struct TypeImpl> : public std::monostate { + using value_type = AttrMap; + + const char* Name() const { return "AttrMap"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/attr_map_method_class.h b/paddle/ap/include/axpr/attr_map_method_class.h new file mode 100644 index 00000000000000..546b0491698daf --- /dev/null +++ 
b/paddle/ap/include/axpr/attr_map_method_class.h @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/to_string.h" + +namespace ap::axpr { + +template +struct AttrMapMethodClass { + using This = AttrMapMethodClass; + using Self = AttrMap; + + adt::Result ToString(axpr::InterpreterBase* interpreter, + const Self& self) { + std::ostringstream ss; + ss << "AttrMap("; + int i = 0; + for (const auto& [k, v] : self->storage) { + if (i++ > 0) { + ss << ", "; + } + ADT_LET_CONST_REF(value_str, axpr::ToString(interpreter, v)); + ss << k << "=" << value_str; + } + ss << ")"; + return ss.str(); + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + ADT_LET_CONST_REF(val, self->Get(attr_name)) << adt::errors::AttributeError{ + std::string() + "'object' has no attribute '" + attr_name + "'."}; + return val; + } +}; + +template +struct MethodClassImpl> + : public AttrMapMethodClass {}; + +template +struct MethodClassImpl>> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/binary_func.h b/paddle/ap/include/axpr/binary_func.h new file mode 100644 index 
00000000000000..66abcb90463e46 --- /dev/null +++ b/paddle/ap/include/axpr/binary_func.h @@ -0,0 +1,64 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ap::axpr { + +#define PEXPR_FOR_EACH_BINARY_OP(_) \ + _(Add, +) \ + _(Sub, -) \ + _(Mul, *) \ + _(Div, /) \ + _(Mod, %) \ + _(EQ, ==) \ + _(NE, !=) \ + _(GT, >) \ + _(GE, >=) \ + _(LT, <) \ + _(LE, <=) + +#define DEFINE_ARITHMETIC_BINARY_OP(name, op) \ + struct Arithmetic##name { \ + static constexpr const char* Name() { return #op; } \ + \ + template \ + static auto Call(const LhsT& lhs, const RhsT& rhs) { \ + return lhs op rhs; \ + } \ + }; +PEXPR_FOR_EACH_BINARY_OP(DEFINE_ARITHMETIC_BINARY_OP); +#undef DEFINE_ARITHMETIC_BINARY_OP + +template +struct BoolIntDoubleBinary { + static constexpr const char* Name() { return ArithmeticOp::Name(); } + template + static auto Call(LhsT lhs, RhsT rhs) { + auto ret = ArithmeticOp::Call(lhs, rhs); + using T = decltype(ret); + if constexpr (std::is_same_v) { + return ret; + } else if constexpr (std::is_integral_v) { + return static_cast(ret); + } else { + static_assert(std::is_floating_point::value, ""); + return static_cast(ret); + } + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/bool.h b/paddle/ap/include/axpr/bool.h new file mode 100644 index 00000000000000..eec20bf8f92fea --- /dev/null +++ b/paddle/ap/include/axpr/bool.h @@ -0,0 +1,30 
@@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using value_type = bool; + + const char* Name() const { return "bool"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/bool_helper.h b/paddle/ap/include/axpr/bool_helper.h new file mode 100644 index 00000000000000..763a8e9e308941 --- /dev/null +++ b/paddle/ap/include/axpr/bool_helper.h @@ -0,0 +1,59 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +struct BoolHelper { + adt::Result ConvertToBool(const axpr::Value& cond) { + using TypeT = typename TypeTrait::TypeT; + return cond.Match( + [](const TypeT&) -> Result { return true; }, + [](const bool c) -> Result { return c; }, + [](const int64_t c) -> Result { return c != 0; }, + [](const double c) -> Result { return c != 0; }, + [](const std::string& c) -> Result { return !c.empty(); }, + [](const Nothing&) -> Result { return false; }, + [](const adt::List& list) -> Result { + return list->size() > 0; + }, + [](const MutableList& list) -> Result { + ADT_LET_CONST_REF(list_ptr, list.Get()); + return list_ptr->size() > 0; + }, + [](const AttrMap& obj) -> Result { + return obj->size() > 0; + }, + [](const Lambda&) -> Result { return true; }, + [](const Closure&) -> Result { return true; }, + [](const Continuation&) -> Result { return true; }, + [](const Method&) -> Result { return true; }, + [](const builtin_symbol::Symbol&) -> Result { return true; }, + [](const BuiltinFuncType&) -> Result { + return true; + }, + [](const BuiltinHighOrderFuncType&) -> Result { + return true; + }, + [&](const auto&) -> Result { + return TypeError{std::string() + "'" + axpr::GetTypeName(cond) + + "' could not be convert to bool"}; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/bool_int_double.h b/paddle/ap/include/axpr/bool_int_double.h new file mode 100644 index 00000000000000..3406735751f3d9 --- /dev/null +++ b/paddle/ap/include/axpr/bool_int_double.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::axpr { + +using BoolIntDoubleImpl = std::variant; + +struct BoolIntDouble : public BoolIntDoubleImpl { + using BoolIntDoubleImpl::BoolIntDoubleImpl; + + ADT_DEFINE_VARIANT_METHODS(BoolIntDoubleImpl); + + template + static adt::Result CastFrom(const ValueT& value) { + using RetT = adt::Result; + return value.Match( + [](bool c) -> RetT { return c; }, + [](int64_t c) -> RetT { return c; }, + [](double c) -> RetT { return c; }, + [](const auto&) -> RetT { + return adt::errors::ValueError{"BoolIntDouble::CastFrom() failed."}; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/bool_int_double_arithmetic_util.h b/paddle/ap/include/axpr/bool_int_double_arithmetic_util.h new file mode 100644 index 00000000000000..af7cb77f7e867a --- /dev/null +++ b/paddle/ap/include/axpr/bool_int_double_arithmetic_util.h @@ -0,0 +1,52 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/binary_func.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/unary_func.h" + +namespace ap::axpr { + +template +Result BoolIntDoubleArithmeticUnaryFunc(const T& value) { + return BoolIntDoubleUnary::Call(value); +} + +template +Result BoolIntDoubleArithmeticBinaryFunc(const T0& lhs, const T1& rhs) { + if constexpr (std::is_same_v) { + if (rhs == 0) { + return adt::errors::ZeroDivisionError{"division by zero"}; + } + } + if constexpr (std::is_same_v) { + if (rhs == 0) { + return adt::errors::ZeroDivisionError{"modulo by zero"}; + } + if constexpr (std::is_floating_point::value // NOLINT + || std::is_floating_point::value) { + return std::fmod(lhs, rhs); + } else { + return BoolIntDoubleBinary::Call(lhs, rhs); + } + } else { + return BoolIntDoubleBinary::Call(lhs, rhs); + } +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/bool_int_double_helper.h b/paddle/ap/include/axpr/bool_int_double_helper.h new file mode 100644 index 00000000000000..d2e1c44540123b --- /dev/null +++ b/paddle/ap/include/axpr/bool_int_double_helper.h @@ -0,0 +1,37 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/bool_int_double.h" +#include "paddle/ap/include/axpr/bool_int_double_arithmetic_util.h" + +namespace ap::axpr { + +template +struct BoolIntDoubleHelper { + template + static adt::Result BinaryFunc(const BoolIntDouble& lhs_val, + const BoolIntDouble& rhs_val) { + const auto& pattern_match = ::common::Overloaded{ + [&](const auto lhs, const auto rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }}; + return std::visit(pattern_match, lhs_val.variant(), rhs_val.variant()); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/bool_method_class.h b/paddle/ap/include/axpr/bool_method_class.h new file mode 100644 index 00000000000000..c581d2fd980d62 --- /dev/null +++ b/paddle/ap/include/axpr/bool_method_class.h @@ -0,0 +1,135 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/bool_int_double_arithmetic_util.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct BoolMethodClass { + using This = BoolMethodClass; + + using Self = bool; + + adt::Result ToString(Self val) { + return std::string(val ? 
"True" : "False"); + } + + adt::Result Hash(Self val) { return static_cast(val); } + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::arithmetic_op_type; + return &This::UnaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::arithmetic_op_type; + return &This::template BinaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static adt::Result BinaryFunc(const ValueT& lhs_val, + const ValueT& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + bool lhs = opt_lhs.GetOkValue(); + return rhs_val.Match( + [&](bool rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](int64_t rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](double rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](const auto& impl) -> adt::Result { + using T = std::decay_t; + return adt::errors::TypeError{ + std::string() + "unsupported operand type(s) for " + + ArithmeticOp::Name() + ": 'bool' and '" + + axpr::GetTypeName(rhs_val) + "'"}; + }); + } + + template + static adt::Result UnaryFunc(const ValueT& val) { + ADT_LET_CONST_REF(operand, val.template TryGet()); + return BoolIntDoubleArithmeticUnaryFunc(operand); + } +}; + +template +struct MethodClassImpl : public BoolMethodClass {}; + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + + adt::Result Call(const TypeImpl&) { return &This::Construct; } + + static adt::Result Construct(const ValueT&, + const std::vector& args) { + 
ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "bool() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + using T = bool; + using RetT = adt::Result; + return args.at(0).Match( + [&](bool c) -> RetT { return static_cast(c); }, + [&](int64_t c) -> RetT { return static_cast(c); }, + [&](double c) -> RetT { return static_cast(c); }, + [&](DataValue data_value) -> RetT { + return data_value.Match( + [&](const axpr::pstring&) -> RetT { + return adt::errors::TypeError{ + "invalid conversion from type 'pstring' to 'bool'"}; + }, + [&](const adt::Undefined&) -> RetT { + return adt::errors::TypeError{ + "invalid conversion from type 'void' to 'bool'"}; + }, + [&](const auto& impl) -> RetT { return static_cast(impl); }); + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "the argument 1 of bool() should be bool/int/float/DataValue"}; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_class_instance.h b/paddle/ap/include/axpr/builtin_class_instance.h new file mode 100644 index 00000000000000..596546066062af --- /dev/null +++ b/paddle/ap/include/axpr/builtin_class_instance.h @@ -0,0 +1,111 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "glog/logging.h" +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/class_attrs.h" +#include "paddle/ap/include/axpr/class_ops.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct BuiltinClassInstance; + +template +struct TypeImpl> { + TypeImpl>( + const ClassOps* class_ops) // NOLINT + : class_ops_(class_ops) {} + + const ClassOps* class_ops_; + + const ClassOps* class_ops() const { return class_ops_; } + + const ClassAttrsImpl* class_attrs() const { + return class_ops_->class_attrs(); + } + + ValueT New(const std::any& any) const; + + const std::string& Name() const { return class_attrs()->Name(); } + + bool operator==(const TypeImpl>& other) const { + return this->class_ops_ == other.class_ops_; + } +}; + +template +struct BuiltinClassInstance { + TypeImpl> type; + std::any instance; + + template + bool Has() const { + return this->instance.type() == typeid(T); + } + + template + adt::Result TryGet() const { + if (this->template Has()) { + return std::any_cast(this->instance); + } else { + return adt::errors::TypeError{ + std::string() + "casting from " + type.Name() + + " class (cpp class name: " + instance.type().name() + ") to " + + typeid(T).name() + " failed."}; + } + } + + bool operator==(const BuiltinClassInstance& other) const { + const auto* class_ops = this->type.class_ops(); + const auto& ret = class_ops->Equals(*this, other); + CHECK(ret.HasOkValue()) + << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" + << ret.GetError().class_name() + << ": BuiltinClassInstance::operator()(): " << ret.GetError().msg(); + return ret.GetOkValue(); + } +}; + +template +ValueT TypeImpl>::New(const std::any& any) const { + return BuiltinClassInstance{*this, any}; +} + +template +using BuiltinFrameValImpl = std::variant, + 
BuiltinHighOrderFuncType, + typename TypeTrait::TypeT>; + +template +ClassAttrs MakeBuiltinClass(const std::string& class_name, + const VisitorT& Visitor) { + AttrMap attr_map; + Visitor([&](const auto& name, const auto& val) { + using TestType = decltype(BuiltinFrameValImpl{val}); + attr_map->Set(name, val); + }); + adt::List>> empty_superclasses{}; + return ClassAttrs{class_name, empty_superclasses, attr_map}; +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_class_instance_method_class.h b/paddle/ap/include/axpr/builtin_class_instance_method_class.h new file mode 100644 index 00000000000000..66b280e0a82a7b --- /dev/null +++ b/paddle/ap/include/axpr/builtin_class_instance_method_class.h @@ -0,0 +1,241 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/builtin_high_order_func_type.h" +#include "paddle/ap/include/axpr/class_attrs_helper.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using Val = ValueT; + using Self = BuiltinClassInstance; + using This = MethodClassImpl; + + adt::Result Hash(InterpreterBase* interpreter, + const Self& self) { + const auto& opt_func = GetClassAttr(self, "__hash__"); + ADT_CHECK(opt_func.has_value()) << adt::errors::TypeError{ + std::string() + self.type.class_attrs()->class_name + + ".__hash__() not implemented"}; + using RetT = adt::Result; + static std::vector empty_args{}; + return opt_func.value().Match( + [&](BuiltinFuncType unary_func) -> RetT { + return unary_func(self, empty_args); + }, + [&](BuiltinHighOrderFuncType unary_func) -> RetT { + return unary_func(interpreter, self, empty_args); + }, + [&](const axpr::Method& method) -> RetT { + return interpreter->InterpretCall(method, {}); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + "casting to builtin function (not " + + GetTypeName(opt_func.value()) + ") failed."}; + }); + } + + adt::Result ToString(InterpreterBase* interpreter, + const Self& self) { + const auto& opt_func = GetClassAttr(self, "__str__"); + ADT_CHECK(opt_func.has_value()) << adt::errors::TypeError{ + std::string() + self.type.class_attrs()->class_name + + ".__str__() not implemented"}; + using RetT = adt::Result; + static std::vector empty_args{}; + return opt_func.value().Match( + [&](BuiltinFuncType unary_func) -> RetT { + return unary_func(self, empty_args); + }, + [&](BuiltinHighOrderFuncType unary_func) -> RetT { + return unary_func(interpreter, self, empty_args); + }, + [&](const 
axpr::Method& method) -> RetT { + return interpreter->InterpretCall(method, {}); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + "casting to builtin function (not " + + GetTypeName(opt_func.value()) + ") failed."}; + }); + } + + adt::Result GetAttr(InterpreterBase* interpreter, + const Self& self, + const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + const auto& opt_val = GetClassAttr(self, attr_name); + if (opt_val.has_value()) { + return opt_val.value(); + } + const auto& opt_gettattr = GetClassAttr(self, "__getattr__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_gettattr.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__getattr__'"}; + std::vector args{attr_name_val}; + ADT_LET_CONST_REF(ret, + interpreter->InterpretCall(opt_gettattr.value(), args)); + return ret; + } + + adt::Result EQ(const Self& self, const ValueT& rhs_val) { + ADT_LET_CONST_REF(ret, Equals(self, rhs_val)); + return ret; + } + + adt::Result NE(const Self& self, const ValueT& rhs_val) { + ADT_LET_CONST_REF(ret, Equals(self, rhs_val)); + return !ret; + } + + adt::Result Equals(const Self& self, const ValueT& rhs_val) { + const auto* class_ops = self.type.class_ops(); + return class_ops->Equals(self, rhs_val); + } + + adt::Result GetItem(InterpreterBase* interpreter, + const Self& self, + const ValueT& idx_val) { + const auto& opt_getitem = GetClassAttr(self, "__getitem__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_getitem.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__getitem__'"}; + std::vector args{idx_val}; + ADT_LET_CONST_REF(ret, + interpreter->InterpretCall(opt_getitem.value(), args)); + return ret; + } + + adt::Result Call(const Self& self) { + const auto& opt_func = GetClassAttr(self, "__call__"); + const auto& 
class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__call__'"}; + return opt_func.value(); + } + + adt::Result Starred(const Self& self) { + const auto& opt_func = GetClassAttr(self, "__starred__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__starred__'"}; + return opt_func.value(); + } + + std::optional GetClassAttr(const Self& self, + const std::string& attr_name) { + const auto& class_attrs = self.type.class_attrs(); + const auto& opt_func = + ClassAttrsHelper{}.OptGet(class_attrs, attr_name); + if (!opt_func.has_value()) { + return std::nullopt; + } + using RetT = ValueT; + return opt_func.value().Match( + [&](const BuiltinFuncType& f) -> RetT { + return Method{self, f}; + }, + [&](const BuiltinHighOrderFuncType& f) -> RetT { + return Method{self, f}; + }, + [&](const auto&) -> RetT { return opt_func.value(); }); + } + + adt::Result SetAttr(const Self& self, const ValueT& attr_name_val) { + const auto& class_attrs = self.type.class_attrs(); + const auto& opt_func = GetClassAttr(self, "__setattr__"); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__setattr__'"}; + return opt_func.value(); + } + + adt::Result SetItem(const Self& self, const ValueT& idx_val) { + const auto& class_attrs = self.type.class_attrs(); + const auto& opt_func = GetClassAttr(self, "__setitem__"); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__setitem__'"}; + return opt_func.value(); + } +}; + +template +struct MethodClassImpl>> { + using Val = ValueT; + using Self = TypeImpl>; + using This = MethodClassImpl; + + adt::Result GetAttr(const Self& 
self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + ADT_LET_CONST_REF(attr_val, self.class_attrs()->attrs->Get(attr_name)) + << adt::errors::AttributeError{ + std::string() + "type '" + self.class_attrs()->class_name + + "' has no attribute '" + attr_name + "'"}; + return attr_val; + } + + adt::Result Call(const Self& self) { + ValueT func{&This::StaticConstruct}; + return Method{self, func}; + } + + adt::Result ToString(const Self& self) { + return std::string() + "class_name + "'>"; + } + + adt::Result Hash(const Self& self) { + return reinterpret_cast(self.class_attrs()); + } + + static adt::Result StaticConstruct( + axpr::InterpreterBase* interpreter, + const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, axpr::TryGetTypeImpl(self_val)); + return This{}.Construct(interpreter, self, args); + } + + adt::Result Construct(axpr::InterpreterBase* interpreter, + const Self& self, + const std::vector& args) { + const auto* class_ops = self.class_ops(); + const auto& class_attrs = class_ops->class_attrs(); + TypeImpl> type(class_ops); + BuiltinClassInstance empty_instance{type, std::nullopt}; + const auto& init_func = + ClassAttrsHelper{}.OptGet(class_attrs, "__init__"); + ADT_CHECK(init_func.has_value()) + << adt::errors::TypeError{std::string() + class_attrs->class_name + + " class has no __init__ function"}; + Method f{empty_instance, init_func.value()}; + ADT_LET_CONST_REF(ret_instance, interpreter->InterpretCall(f, args)); + return ret_instance; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_classes.h b/paddle/ap/include/axpr/builtin_classes.h new file mode 100644 index 00000000000000..311791b3f97e15 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_classes.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::axpr { + +template +adt::Result VisitEachBuiltinClass(const DoEachT& DoEach) { + return adt::Ok{}; +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_environment.h b/paddle/ap/include/axpr/builtin_environment.h new file mode 100644 index 00000000000000..c189470ae07dfa --- /dev/null +++ b/paddle/ap/include/axpr/builtin_environment.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/environment.h" + +namespace ap::axpr { + +template +class BuiltinEnvironment : public Environment { + public: + explicit BuiltinEnvironment(const AttrMap& builtin_object) + : builtin_object_(builtin_object) {} + + adt::Result Get(const std::string& var) const override { + return builtin_object_->Get(var); + } + + adt::Result Set(const std::string& var, const ValueT& val) override { + return adt::errors::RuntimeError{"builtin environment is immutable."}; + } + + std::optional> RecursivelyGetConstGlobalFrame() + const override { + return std::nullopt; + } + + private: + BuiltinEnvironment(const BuiltinEnvironment&) = delete; + BuiltinEnvironment(BuiltinEnvironment&&) = delete; + + AttrMap builtin_object_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_frame_util.h b/paddle/ap/include/axpr/builtin_frame_util.h new file mode 100644 index 00000000000000..0e7362608527b7 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_frame_util.h @@ -0,0 +1,60 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/builtin_functions.h" +#include "paddle/ap/include/axpr/builtin_symbol.h" +#include "paddle/ap/include/axpr/exception_method_class.h" +#include "paddle/ap/include/axpr/module_mgr_helper.h" + +namespace ap::axpr { + +template +void VisitEachBuiltinFrameAttr(const YieldT& Yield) { + AttrMap base{ValueT::GetExportedTypes()}; + for (const auto& [k, v] : base->storage) { + Yield(k, v); + } + Yield("import", &ModuleMgrHelper::ImportModule); + Yield("apply", &Apply); + Yield("print", &Print); + Yield("replace_or_trim_left_comma", &ReplaceOrTrimLeftComma); + Yield("range", &MakeRange); + Yield("flat_map", &FlatMap); + Yield("map", &Map); + Yield("filter", &Filter); + Yield("reduce", &Reduce); + Yield("zip", &Zip); + Yield("max", &Max); + Yield("min", &Min); + Yield("len", &Length); + Yield("getattr", &GetAttr); + Yield("setattr", &SetAttr); + ForEachExceptionConstructor(Yield); + Yield("raise", &Raise); + Yield("__builtin_not__", &BuiltinNot); +} + +template +AttrMap MakeBuiltinFrameAttrMap() { + AttrMap attr_map; + VisitEachBuiltinFrameAttr( + [&](const std::string& k, const ValueT& v) { attr_map->Set(k, v); }); + return attr_map; +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_func_name_mgr.h b/paddle/ap/include/axpr/builtin_func_name_mgr.h new file mode 100644 index 00000000000000..909ec2aeabbda1 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_func_name_mgr.h @@ -0,0 +1,75 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "glog/logging.h" + +namespace ap::axpr { + +struct BuiltinFuncName { + std::optional module_name; + std::string func_name; + + std::string ToString() const { + return OptStrToStr(module_name) + "." + func_name; + } + + static std::string OptStrToStr(const std::optional& opt_str) { + if (opt_str.has_value()) return opt_str.value(); + return "__builtin_frame__"; + } +}; + +class BuiltinFuncNameMgr { + public: + bool Has(void* ptr) const { return func_ptr2name_.count(ptr) > 0; } + + std::optional OptGet(void* ptr) const { + const auto& iter = func_ptr2name_.find(ptr); + if (iter == func_ptr2name_.end()) return std::nullopt; + return iter->second; + } + + void Register(const std::optional& module_name, + const std::string& func_name, + void* func_ptr) { + CHECK(func_ptr2name_ + .emplace(func_ptr, BuiltinFuncName{module_name, func_name}) + .second) + << "redundant name for builtin function: old_module_name: " + << ToString(func_ptr2name_[func_ptr].module_name) + << ", old_func_name: " << func_ptr2name_[func_ptr].func_name + << ", new_module_name: " << ToString(module_name) + << ", new_func_name: " << func_name; + } + + static BuiltinFuncNameMgr* Singleton() { + static BuiltinFuncNameMgr mgr{}; + return &mgr; + } + + private: + BuiltinFuncNameMgr() {} + + static std::string ToString(const std::optional& opt_str) { + return BuiltinFuncName::OptStrToStr(opt_str); + } + + std::unordered_map func_ptr2name_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_func_type.h 
b/paddle/ap/include/axpr/builtin_func_type.h new file mode 100644 index 00000000000000..625861446310ab --- /dev/null +++ b/paddle/ap/include/axpr/builtin_func_type.h @@ -0,0 +1,74 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/interpreter_base.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +using BuiltinFuncType = Result (*)(const ValueT&, + const std::vector& args); + +template (This::*func)( + const typename This::Self&, + const std::vector& args)> +Result WrapAsBuiltinFuncType( + const typename This::Val& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template TryGet()); + return (This{}.*func)(self, args); +} + +template +struct TypeImpl> : public std::monostate { + using value_type = BuiltinFuncType; + + const char* Name() const { return "builtin_function"; } +}; + +template +using BuiltinHighOrderFuncType = + Result (*)(InterpreterBase* interpreter, + const ValueT& obj, + const std::vector& args); + +template +struct TypeImpl> : public std::monostate { + using value_type = BuiltinHighOrderFuncType; + + const char* Name() const { return "builtin_high_order_function"; } +}; + +template +using BuiltinFunctionImpl = + std::variant, BuiltinHighOrderFuncType>; + +template +struct BuiltinFunction 
: public BuiltinFunctionImpl { + using BuiltinFunctionImpl::BuiltinFunctionImpl; + ADT_DEFINE_VARIANT_METHODS(BuiltinFunctionImpl); + + template + T CastTo() const { + return Match([](const auto& impl) -> T { return impl; }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_func_type_method_class.h b/paddle/ap/include/axpr/builtin_func_type_method_class.h new file mode 100644 index 00000000000000..247adb34d259e4 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_func_type_method_class.h @@ -0,0 +1,47 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct BuiltinFuncTypeMethodClass { + using This = BuiltinFuncTypeMethodClass; + using Self = BuiltinFuncType; + + adt::Result ToString(Self func) { + std::ostringstream ss; + ss << "<" << TypeImpl{}.Name() << " object at " << func << ">"; + return ss.str(); + } + + adt::Result Hash(Self func) { + return reinterpret_cast(func); + } +}; + +template +struct MethodClassImpl> + : public BuiltinFuncTypeMethodClass {}; + +template +struct MethodClassImpl>> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_functions.h b/paddle/ap/include/axpr/builtin_functions.h new file mode 100644 index 00000000000000..9af1e61b573f33 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_functions.h @@ -0,0 +1,90 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +Result BuiltinIdentity(const axpr::Value&, + const std::vector& args); + +Result BuiltinNot(const axpr::Value&, + const std::vector& args); + +Result Raise(const axpr::Value&, + const std::vector& args); + +Result BuiltinList(const axpr::Value&, + const std::vector& args); + +Result BuiltinHalt(const axpr::Value&, + const std::vector& args); + +adt::Result Print(InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +adt::Result ReplaceOrTrimLeftComma( + const axpr::Value&, const std::vector& args); + +adt::Result MakeRange(const axpr::Value&, + const std::vector& args); + +Result FlatMap(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result Map(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result Apply(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result Length(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result Filter(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result Zip(const axpr::Value&, + const std::vector& args); + +Result Reduce(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result Max(const axpr::Value&, + const std::vector& args); + +Result Min(const axpr::Value&, + const std::vector& args); + +Result Min(const axpr::Value&, + const std::vector& args); + +Result GetAttr(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); + +Result SetAttr(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args); +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_high_order_func_type.h b/paddle/ap/include/axpr/builtin_high_order_func_type.h 
new file mode 100644 index 00000000000000..8938b810dfd432 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_high_order_func_type.h @@ -0,0 +1,17 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/builtin_func_type.h" diff --git a/paddle/ap/include/axpr/builtin_high_order_func_type_method_class.h b/paddle/ap/include/axpr/builtin_high_order_func_type_method_class.h new file mode 100644 index 00000000000000..31a03452fccbeb --- /dev/null +++ b/paddle/ap/include/axpr/builtin_high_order_func_type_method_class.h @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_high_order_func_type.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using Self = BuiltinHighOrderFuncType; + using This = MethodClassImpl; + + adt::Result ToString(Self func) { + std::ostringstream ss; + ss << "<" << TypeImpl{}.Name() << " object at " << func << ">"; + return ss.str(); + } + + adt::Result Hash(Self func) { + return reinterpret_cast(func); + } +}; + +template +struct MethodClassImpl>> {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_serializable_attr_map.h b/paddle/ap/include/axpr/builtin_serializable_attr_map.h new file mode 100644 index 00000000000000..6ede9498e038a2 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_serializable_attr_map.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl> : public std::monostate { + using value_type = AttrMap; + + const char* Name() const { return "BuiltinSerializableAttrMap"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_serializable_attr_map_method_class.h b/paddle/ap/include/axpr/builtin_serializable_attr_map_method_class.h new file mode 100644 index 00000000000000..6b0d79a1689d62 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_serializable_attr_map_method_class.h @@ -0,0 +1,91 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/builtin_serializable_attr_map.h" +#include "paddle/ap/include/axpr/class_instance.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/packed_args.h" +#include "paddle/ap/include/axpr/serializable_value_helper.h" + +namespace ap::axpr { + +template +struct BuiltinSerializableAttrMapMethodClass { + using This = BuiltinSerializableAttrMapMethodClass; + using Self = AttrMap; + + adt::Result Length(const Self& self) { + return static_cast(self->size()); + } + + adt::Result ToString(const Self& self) { + ADT_LET_CONST_REF(str, SerializableValueHelper{}.ToString(self)); + return str; + } + + adt::Result Hash(const Self& self) { + ADT_LET_CONST_REF(hash_value, SerializableValueHelper{}.Hash(self)); + return hash_value; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + ADT_LET_CONST_REF(val, self->Get(attr_name)) << adt::errors::AttributeError{ + std::string() + "'BuiltinSerializableAttrMap' has no attribute '" + + attr_name + "'."}; + return val.template CastTo(); + } +}; + +template +struct TypeImplBuiltinSerializableAttrMapMethodClass { + using This = TypeImplBuiltinSerializableAttrMapMethodClass; + using Self = TypeImpl>; + + adt::Result Call(const Self&) { return &This::StaticConstruct; } + + static adt::Result StaticConstruct(const ValueT&, + const std::vector& args) { + return This{}.Construct(args); + } + + adt::Result Construct(const std::vector& args) { + const auto& packed_args = CastToPackedArgs(args); + const auto& [pos_args, kwargs] = *packed_args; + ADT_CHECK(pos_args->empty()) + << adt::errors::TypeError{std::string() + + "the construct of BuiltinSerializableAttrMap " + "takes no positional arguments."}; + ADT_LET_CONST_REF(serializable_val, + SerializableValueHelper{}.CastObjectFrom(kwargs)); + ADT_LET_CONST_REF( + 
serializable_obj, + serializable_val.template TryGet>()); + return serializable_obj; + } +}; + +template +struct MethodClassImpl> + : public BuiltinSerializableAttrMapMethodClass {}; + +template +struct MethodClassImpl>> + : public TypeImplBuiltinSerializableAttrMapMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h b/paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h new file mode 100644 index 00000000000000..1fb29622e6f06d --- /dev/null +++ b/paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h @@ -0,0 +1,140 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +struct BuiltinSerializableAttrMapToAxprHelper { + using AnfExpr = axpr::AnfExpr; + + adt::Result Convert( + ap::axpr::LetContext* ctx, + const ap::axpr::AttrMap& attr_map) const { + return GetCodeFromBuiltinSerializableAttrMap(ctx, attr_map); + } + + private: + adt::Result GetCodeFromBuiltinSerializableAttrMap( + ap::axpr::LetContext* ctx, + const ap::axpr::AttrMap& attr_map) const { + std::map kwargs; + for (const auto& [keyword, val] : attr_map->storage) { + ADT_LET_CONST_REF(val_anf, + GetCodeFromBuiltinSerializableAttrMapItem(ctx, val)); + kwargs[keyword] = val_anf; + } + return ctx->Apply("BuiltinSerializableAttrMap", {}, kwargs); + } + + adt::Result GetCodeFromBuiltinSerializableAttrMapItem( + ap::axpr::LetContext* ctx, + const ap::axpr::SerializableValue& item) const { + return item.Match( + [&](const adt::Nothing&) -> adt::Result { + return ctx->None(); + }, + [&](bool c) -> adt::Result { return ctx->Bool(c); }, + [&](int64_t c) -> adt::Result { return ctx->Int64(c); }, + [&](double c) -> adt::Result { return ctx->Double(c); }, + [&](const std::string& str) -> adt::Result { + return ctx->String(str); + }, + [&](const adt::List& l) + -> adt::Result { + return GetCodeFromBuiltinSerializableAttrMapList(ctx, l); + }, + [&](const ap::axpr::AttrMap& object) + -> adt::Result { + return GetCodeFromBuiltinSerializableAttrMap(ctx, object); + }, + [&](const ap::axpr::Function& function) + -> adt::Result { + const auto& lambda = function->lambda; + const AnfExpr& anf_expr = ap::axpr::ConvertCoreExprToAnfExpr(lambda); + AnfExpr ret{ctx->Attr(anf_expr, "__function__")}; + return ret; + }, + [&](const axpr::TypeImpl& impl) -> adt::Result { + return ctx->Var(impl.Name()); + }, + [&](const 
axpr::TypeImpl& impl) -> adt::Result { + return ctx->Var(impl.Name()); + }, + [&](const axpr::TypeImpl& impl) -> adt::Result { + return ctx->Var(impl.Name()); + }, + [&](const axpr::TypeImpl& impl) -> adt::Result { + return ctx->Var(impl.Name()); + }, + [&](const axpr::TypeImpl& impl) -> adt::Result { + return ctx->Var(impl.Name()); + }, + [&](const axpr::ClassAttrs&) + -> adt::Result { + return adt::errors::NotImplementedError{ + "serialization of axpr::ClassAttrs not " + "implemented"}; + }, + [&](const axpr::BuiltinFuncVoidPtr& func) -> adt::Result { + const auto& name_info = + axpr::BuiltinFuncNameMgr::Singleton()->OptGet(func.func_ptr); + ADT_CHECK(name_info.has_value()); + if (name_info.value().module_name.has_value()) { + const auto& module_name = + ctx->String(name_info.value().module_name.value()); + const auto& func_name = name_info.value().func_name; + return ctx->Var("import").Call(module_name).Attr(func_name); + } else { + const auto& func_name = name_info.value().func_name; + return ctx->Var(func_name); + } + }, + [&](const axpr::BuiltinHighOrderFuncVoidPtr& func) + -> adt::Result { + const auto& name_info = + axpr::BuiltinFuncNameMgr::Singleton()->OptGet(func.func_ptr); + ADT_CHECK(name_info.has_value()); + if (name_info.value().module_name.has_value()) { + const auto& module_name = + ctx->String(name_info.value().module_name.value()); + const auto& func_name = name_info.value().func_name; + return ctx->Var("import").Call(module_name).Attr(func_name); + } else { + const auto& func_name = name_info.value().func_name; + return ctx->Var(func_name); + } + }); + } + + adt::Result GetCodeFromBuiltinSerializableAttrMapList( + ap::axpr::LetContext* ctx, + const adt::List& list) const { + std::vector elt_anf_exprs; + for (const auto& elt : *list) { + ADT_LET_CONST_REF(elt_anf_expr, + GetCodeFromBuiltinSerializableAttrMapItem(ctx, elt)); + elt_anf_exprs.emplace_back(elt_anf_expr); + } + return ctx->Call(ap::axpr::kBuiltinList(), elt_anf_exprs); + } +}; + +} // 
namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_symbol.h b/paddle/ap/include/axpr/builtin_symbol.h new file mode 100644 index 00000000000000..25774d7947c87e --- /dev/null +++ b/paddle/ap/include/axpr/builtin_symbol.h @@ -0,0 +1,279 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/binary_func.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/unary_func.h" + +namespace ap::axpr { + +inline constexpr const char* kBuiltinIf() { return "__builtin_if__"; } +inline constexpr const char* kBuiltinIdentity() { + return "__builtin_identity__"; +} +inline constexpr const char* kBuiltinList() { return "__builtin_list__"; } +inline constexpr const char* kBuiltinStarred() { return "__builtin_starred__"; } +inline constexpr const char* kBuiltinCall() { return "__builtin_call__"; } +inline constexpr const char* kBuiltinToString() { + return "__builtin_ToString__"; +} +inline constexpr const char* kBuiltinHash() { return "__builtin_hash__"; } +inline constexpr const char* kBuiltinGetAttr() { return "__builtin_getattr__"; } +inline constexpr const char* kBuiltinSetAttr() { return "__builtin_setattr__"; } +inline constexpr const char* kBuiltinGetItem() { return "__builtin_getitem__"; } +inline constexpr const char* kBuiltinSetItem() { return "__builtin_setitem__"; } +inline 
constexpr const char* kBuiltinLength() { return "__builtin_len__"; } +inline constexpr const char* kBuiltinReturn() { return "__builtin_return__"; } + +#define DEFINE_PEXPR_BUILTIN_CONSTANT_NAME(name, op) \ + inline constexpr const char* kBuiltin##name() { \ + return "__builtin_" #name "__"; \ + } +PEXPR_FOR_EACH_BINARY_OP(DEFINE_PEXPR_BUILTIN_CONSTANT_NAME) +PEXPR_FOR_EACH_UNARY_OP(DEFINE_PEXPR_BUILTIN_CONSTANT_NAME) +#undef DEFINE_PEXPR_BUILTIN_CONSTANT_NAME + +namespace builtin_symbol { + +struct If : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinIf(); } + std::size_t GetHashValue() const { return 0; } +}; + +struct Id : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinIdentity(); } + std::size_t GetHashValue() const { return 0; } +}; + +struct List : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinList(); } + std::size_t GetHashValue() const { return 0; } +}; + +struct Starred : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinStarred(); } + static constexpr int num_operands = 1; + std::size_t GetHashValue() const { return 0; } +}; + +struct Call : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinCall(); } + static constexpr int num_operands = 1; + std::size_t GetHashValue() const { return 0; } +}; + +struct ToString : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinToString(); } + static constexpr int num_operands = 1; + std::size_t GetHashValue() const { return 0; } +}; + +struct Hash : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinHash(); } + static constexpr int num_operands = 1; + std::size_t 
GetHashValue() const { return 0; } +}; + +struct GetAttr : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinGetAttr(); } + static constexpr int num_operands = 2; + std::size_t GetHashValue() const { return 0; } +}; + +struct SetAttr : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinSetAttr(); } + static constexpr int num_operands = 2; + std::size_t GetHashValue() const { return 0; } +}; + +struct GetItem : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinGetItem(); } + static constexpr int num_operands = 2; + std::size_t GetHashValue() const { return 0; } +}; + +struct SetItem : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinSetItem(); } + static constexpr int num_operands = 2; + std::size_t GetHashValue() const { return 0; } +}; + +struct Length : public std::monostate { + using std::monostate::monostate; + static constexpr const char* Name() { return kBuiltinLength(); } + static constexpr int num_operands = 1; + std::size_t GetHashValue() const { return 0; } +}; + +#define DEFINE_UNARY_SYMBOL(name, op) \ + struct name : public std::monostate { \ + using std::monostate::monostate; \ + static constexpr const char* Name() { return kBuiltin##name(); } \ + static constexpr int num_operands = 1; \ + std::size_t GetHashValue() const { return 0; } \ + }; + +PEXPR_FOR_EACH_UNARY_OP(DEFINE_UNARY_SYMBOL); + +#undef DEFINE_UNARY_SYMBOL; + +#define DEFINE_BINARY_SYMBOL(name, op) \ + struct name : public std::monostate { \ + using std::monostate::monostate; \ + static constexpr const char* Name() { return kBuiltin##name(); } \ + static constexpr int num_operands = 2; \ + std::size_t GetHashValue() const { return 0; } \ + }; + +PEXPR_FOR_EACH_BINARY_OP(DEFINE_BINARY_SYMBOL); + +#undef DEFINE_BINARY_SYMBOL; + +#define 
AXPR_FOR_EACH_SYMBOL_OP(_) \ + PEXPR_FOR_EACH_BINARY_OP(_) \ + PEXPR_FOR_EACH_UNARY_OP(_) \ + _(Call, ()) \ + _(ToString, str) \ + _(Hash, hash) \ + _(Starred, *) \ + _(GetAttr, .) \ + _(SetAttr, .) \ + _(GetItem, []) \ + _(SetItem, []) \ + _(Length, len) + +using OpImpl = std::variant< +#define MAKE_OP_IMPL_ALTERNATIVE(name, op) name, + PEXPR_FOR_EACH_BINARY_OP(MAKE_OP_IMPL_ALTERNATIVE) + PEXPR_FOR_EACH_UNARY_OP(MAKE_OP_IMPL_ALTERNATIVE) +#undef MAKE_OP_IMPL_ALTERNATIVE + Call, + ToString, + Hash, + Starred, + GetAttr, + SetAttr, + GetItem, + SetItem, + Length>; + +struct Op : public OpImpl { + using OpImpl::OpImpl; + ADT_DEFINE_VARIANT_METHODS(OpImpl); + + const char* Name() const { + return Match([](const auto& impl) { return impl.Name(); }); + } + + std::size_t GetHashValue() const { + std::size_t hash_value = + Match([&](const auto& impl) { return impl.GetHashValue(); }); + return adt::hash_combine(hash_value, this->index()); + } +}; + +using SymbolImpl = std::variant; + +struct Symbol : public SymbolImpl { + using SymbolImpl::SymbolImpl; + ADT_DEFINE_VARIANT_METHODS(SymbolImpl); + + const char* Name() const { + return Match([](const auto& impl) { return impl.Name(); }); + } + + std::size_t GetHashValue() const { + std::size_t hash_value = + Match([&](const auto& impl) { return impl.GetHashValue(); }); + return adt::hash_combine(hash_value, this->index()); + } +}; + +inline adt::Maybe GetSymbolFromString(const std::string& name) { + static const std::unordered_map map{ + {If::Name(), If{}}, + {Id::Name(), Id{}}, + {List::Name(), List{}}, + {Call::Name(), Op{Call{}}}, + {ToString::Name(), Op{ToString{}}}, + {Hash::Name(), Op{Hash{}}}, + {Starred::Name(), Op{Starred{}}}, + {GetAttr::Name(), Op{GetAttr{}}}, + {SetAttr::Name(), Op{SetAttr{}}}, + {GetItem::Name(), Op{GetItem{}}}, + {SetItem::Name(), Op{SetItem{}}}, + {Length::Name(), Op{Length{}}}, +#define MAKE_SYMBOL_ENTRY(cls, op) {cls::Name(), Op{cls{}}}, + PEXPR_FOR_EACH_BINARY_OP(MAKE_SYMBOL_ENTRY) + 
PEXPR_FOR_EACH_UNARY_OP(MAKE_SYMBOL_ENTRY) +#undef MAKE_SYMBOL_ENTRY + }; + const auto& iter = map.find(name); + if (iter == map.end()) { + return adt::Nothing{}; + } + return iter->second; +} + +} // namespace builtin_symbol + +template +struct ConvertBuiltinSymbolToArithmetic { + static const bool convertible = false; + using arithmetic_op_type = void; +}; + +#define SPECIALIZE_ConvertBuiltinSymbolToArithmetic(cls, op) \ + template <> \ + struct ConvertBuiltinSymbolToArithmetic { \ + static const bool convertible = true; \ + using arithmetic_op_type = Arithmetic##cls; \ + }; + +PEXPR_FOR_EACH_BINARY_OP(SPECIALIZE_ConvertBuiltinSymbolToArithmetic); +PEXPR_FOR_EACH_UNARY_OP(SPECIALIZE_ConvertBuiltinSymbolToArithmetic); +#undef SPECIALIZE_ConvertBuiltinSymbolToArithmetic + +template +constexpr const char* GetBuiltinSymbolDebugString() { + if constexpr (ConvertBuiltinSymbolToArithmetic::convertible) { + return ConvertBuiltinSymbolToArithmetic< + BuiltinSymbol>::arithmetic_op_type::Name(); + } else { + return BuiltinSymbol::Name(); + } +} + +template <> +struct TypeImpl : public std::monostate { + using value_type = builtin_symbol::Symbol; + + const char* Name() const { return "builtin_symbol"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/builtin_symbol_method_class.h b/paddle/ap/include/axpr/builtin_symbol_method_class.h new file mode 100644 index 00000000000000..f64fdba33c1bc6 --- /dev/null +++ b/paddle/ap/include/axpr/builtin_symbol_method_class.h @@ -0,0 +1,45 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/builtin_symbol.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct BuiltinSymbolMethodClass { + using This = BuiltinSymbolMethodClass; + using Self = builtin_symbol::Symbol; + + adt::Result ToString(const Self& self) { + return std::string(self.Name()); + } + + adt::Result Hash(const Self& self) { + return static_cast(std::hash()(self.Name())); + } +}; + +template +struct MethodClassImpl + : public BuiltinSymbolMethodClass {}; + +template +struct MethodClassImpl> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/call_environment.h b/paddle/ap/include/axpr/call_environment.h new file mode 100644 index 00000000000000..fe316b57b09054 --- /dev/null +++ b/paddle/ap/include/axpr/call_environment.h @@ -0,0 +1,67 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/environment.h" +#include "paddle/ap/include/axpr/frame.h" + +namespace ap::axpr { + +template +class CallEnvironment : public Environment { + public: + CallEnvironment(const std::shared_ptr>& parent, + const Frame& frame) + : parent_(parent), frame_(frame) {} + + adt::Result Get(const std::string& var) const override { + ADT_LET_CONST_REF(frame_ptr, frame_.Get()); + const auto& res = frame_ptr->OptGet(var); + if (res.has_value()) { + return res.value(); + } + if (parent_ == nullptr) { + return NameError{std::string("name '") + var + "' is not defined."}; + } + return parent_->Get(var); + } + + adt::Result Set(const std::string& var, const ValueT& val) override { + ADT_LET_CONST_REF(frame_ptr, frame_.Mut()); + frame_ptr->Set(var, val); + return adt::Ok{}; + } + + std::optional> RecursivelyGetConstGlobalFrame() + const override { + if (parent_ == nullptr) { + return std::nullopt; + } + return parent_->RecursivelyGetConstGlobalFrame(); + } + + const Frame& frame() const { return frame_; } + + private: + CallEnvironment(const CallEnvironment&) = delete; + CallEnvironment(CallEnvironment&&) = delete; + + std::shared_ptr> parent_; + Frame frame_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/callable_helper.h b/paddle/ap/include/axpr/callable_helper.h new file mode 100644 index 00000000000000..c59db2f06b548f --- /dev/null +++ b/paddle/ap/include/axpr/callable_helper.h @@ -0,0 +1,53 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/class_attrs_helper.h" +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +struct CallableHelper { + bool IsCallable(const axpr::Value& value) const { + return value.Match( + [&](const BuiltinFuncType& func) -> bool { return true; }, + [&](const BuiltinHighOrderFuncType& func) -> bool { + return true; + }, + [&](const Method& method) -> bool { return true; }, + [&](const Closure& closure) -> bool { return true; }, + [&](const Continuation& continuation) -> bool { + return true; + }, + [&](const Function& function) -> bool { + return true; + }, + [&](const builtin_symbol::Symbol& symbol) -> bool { return true; }, + [&](const BuiltinClassInstance& builtin_class_instance) + -> bool { + const auto* class_attrs = builtin_class_instance.type.class_attrs(); + ClassAttrsHelper helper{}; + return helper.OptGet(class_attrs, "__call__").has_value(); + }, + [&](const ClassInstance& class_instance) -> bool { + const auto& class_attrs = class_instance->type.class_attrs; + ClassAttrsHelper helper{}; + return helper.OptGet(class_attrs, "__call__").has_value(); + }, + [&](const auto&) -> bool { return false; }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/class_attrs.h b/paddle/ap/include/axpr/class_attrs.h new file mode 100644 index 00000000000000..15b06d1ae83b78 --- /dev/null +++ b/paddle/ap/include/axpr/class_attrs.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" + +namespace ap::axpr { + +template +struct ClassAttrsImpl { + std::string class_name; + adt::List> superclasses; + AttrMap attrs; + + const std::string& Name() const { return this->class_name; } + + bool operator==(const ClassAttrsImpl& other) const { return this == &other; } +}; + +template +ADT_DEFINE_RC(ClassAttrs, ClassAttrsImpl); + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/class_attrs_helper.h b/paddle/ap/include/axpr/class_attrs_helper.h new file mode 100644 index 00000000000000..40ae9970102ae6 --- /dev/null +++ b/paddle/ap/include/axpr/class_attrs_helper.h @@ -0,0 +1,56 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/class_instance.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +template +struct ClassAttrsHelper { + std::optional OptGet(const ClassAttrs& class_attrs, + const std::string& attr_name) { + return ImplOptGet(class_attrs.shared_ptr().get(), attr_name); + } + + std::optional OptGet(const ClassAttrsImpl* class_attrs, + const std::string& attr_name) { + return ImplOptGet(class_attrs, attr_name); + } + + private: + std::optional ImplOptGet( + const ClassAttrsImpl* class_attrs_impl, + const std::string& attr_name) { + const auto& opt_val = class_attrs_impl->attrs->OptGet(attr_name); + if (opt_val.has_value()) { + if constexpr (std::is_same_v) { + return opt_val.value(); + } else { + return opt_val.value().template CastTo(); + } + } + for (const auto& base : *class_attrs_impl->superclasses) { + if (const auto val_in_base = ImplOptGet(base.get(), attr_name)) { + return val_in_base.value(); + } + } + return std::nullopt; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/class_instance.h b/paddle/ap/include/axpr/class_instance.h new file mode 100644 index 00000000000000..0e53d9b9eed8d1 --- /dev/null +++ b/paddle/ap/include/axpr/class_instance.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/class_attrs.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/instance_attrs.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct ClassInstance; + +template +struct TypeImpl> { + explicit TypeImpl>( + const ClassAttrs& class_attr_val) + : class_attrs(class_attr_val) {} + + ClassAttrs class_attrs; + + const std::string& Name() const { return class_attrs->Name(); } + + bool operator==(const TypeImpl>& other) const { + return this->class_attrs == other.class_attrs; + } +}; + +template +struct ClassInstanceImpl { + TypeImpl> type; + InstanceAttrs instance_attrs; + + bool operator==(const ClassInstanceImpl& other) const { + return this == &other; + } +}; + +template +ADT_DEFINE_RC(ClassInstance, ClassInstanceImpl); + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/class_instance_method_class.h b/paddle/ap/include/axpr/class_instance_method_class.h new file mode 100644 index 00000000000000..543c493d5fe4aa --- /dev/null +++ b/paddle/ap/include/axpr/class_instance_method_class.h @@ -0,0 +1,202 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_high_order_func_type.h" +#include "paddle/ap/include/axpr/class_attrs_helper.h" +#include "paddle/ap/include/axpr/class_instance.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using Val = ValueT; + using Self = ClassInstance; + using This = MethodClassImpl; + + adt::Result Hash(InterpreterBase* interpreter, + const Self& self) { + const auto& opt_func = GetClassAttr(self, "__hash__"); + if (!opt_func.has_value()) { + return reinterpret_cast(self.shared_ptr().get()); + } + std::vector args{self}; + ADT_LET_CONST_REF(hash_value, + interpreter->InterpretCall(opt_func.value(), args)); + ADT_CHECK(hash_value.template Has()) + << adt::errors::TypeError{"__hash__ method should return an integer"}; + return hash_value; + } + + adt::Result ToString(InterpreterBase* interpreter, + const Self& self) { + const auto& opt_func = GetClassAttr(self, "__str__"); + if (!opt_func.has_value()) { + std::ostringstream ss; + const auto* ptr = self.shared_ptr().get(); + ss << "<" << self->type.class_attrs->class_name << " object at " << ptr + << ">"; + return ss.str(); + } + std::vector args{self}; + ADT_LET_CONST_REF(str, interpreter->InterpretCall(opt_func.value(), args)); + ADT_CHECK(str.template Has()) + << adt::errors::TypeError{"__str__ method should return a str"}; + return str; + } + + adt::Result GetAttr(InterpreterBase* interpreter, + const Self& self, + const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "type: '" + self->type.class_attrs->class_name + + "'. 
attr_name should be a str, but " + + axpr::GetTypeName(attr_name_val) + " were given"}; + ADT_LET_CONST_REF(instance_attrs, self->instance_attrs.Get()); + if (instance_attrs->Has(attr_name)) { + return instance_attrs->Get(attr_name); + } + const auto& opt_func = GetClassAttr(self, attr_name); + ADT_CHECK(opt_func.has_value()) << adt::errors::AttributeError{ + std::string() + "type object '" + self->type.class_attrs->class_name + + "' has no attribute '" + attr_name + "'"}; + if (opt_func.has_value()) { + return opt_func.value(); + } + const auto& opt_getter = GetClassAttr(self, "__getattr__"); + ADT_CHECK(opt_getter.has_value()) << adt::errors::AttributeError{ + std::string() + "type object '" + self->type.class_attrs->class_name + + "' has no attribute '__getattr__'"}; + std::vector args{attr_name_val}; + ADT_LET_CONST_REF(ret, + interpreter->InterpretCall(opt_getter.value(), args)); + return ret; + } + + adt::Result Call(const Self& self) { + const auto& opt_func = GetClassAttr(self, "__call__"); + ADT_CHECK(opt_func.has_value()) << adt::errors::AttributeError{ + std::string() + "type object '" + self->type.class_attrs->class_name + + "' has no attribute '__call__'"}; + return opt_func.value(); + } + + std::optional GetClassAttr(const Self& self, + const std::string& attr_name) { + const auto& class_attrs = self->type.class_attrs; + const auto& opt_func = ClassAttrsHelper{}.OptGet( + class_attrs, attr_name); + if (!opt_func.has_value()) { + return std::nullopt; + } + return opt_func.value().Match( + [&](const Function& f) -> ValueT { + return Method{self, f}; + }, + [&](const auto&) -> ValueT { return opt_func.value(); }); + } + + adt::Result SetAttr(const Self& self, const ValueT& attr_name_val) { + return Method{self, &This::SetInstanceAttr}; + } + + static adt::Result SetInstanceAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + + "type(self) is 
unexpected. given: " + GetTypeName(self_val)}; + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(attr_name, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "SetInstanceAttr() failed. args.at(0) should be a str. " + "type(self): " + + axpr::GetTypeName(self_val) + + ", type(args.at(0)): " + axpr::GetTypeName(args.at(0))}; + ADT_LET_CONST_REF(instance_attrs, self->instance_attrs.Mut()); + instance_attrs->Set(attr_name, args.at(1)); + return adt::Nothing{}; + } +}; + +template +struct MethodClassImpl>> { + using Val = ValueT; + using Self = TypeImpl>; + using This = MethodClassImpl; + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + ADT_LET_CONST_REF(attr, self.class_attrs->attrs->Get(attr_name)) + << adt::errors::AttributeError{ + std::string() + "type object '" + self.class_attrs->class_name + + "' has no attribute '" + attr_name + "'"}; + return attr.template CastTo(); + } + + adt::Result Call(const Self& self) { + ValueT func{&This::StaticConstruct}; + return Method{self, func}; + } + + adt::Result ToString(const Self& self) { + return std::string() + "class_name + "'>"; + } + + adt::Result Hash(const Self& self) { + return reinterpret_cast(self.class_attrs.shared_ptr().get()); + } + + static adt::Result StaticConstruct( + axpr::InterpreterBase* interpreter, + const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, axpr::TryGetTypeImpl(self_val)); + return This{}.Construct(interpreter, self, args); + } + + adt::Result Construct(axpr::InterpreterBase* interpreter, + const Self& self, + const std::vector& args) { + const auto& class_attrs = self.class_attrs; + ADT_LET_CONST_REF(ref_lst, + adt::WeakPtrLock(interpreter->circlable_ref_list())); + const auto& instance = [&] { + const auto& instance_attrs = InstanceAttrs::Make( + ref_lst, std::make_shared>()); + TypeImpl> type(class_attrs); + return ClassInstance{type, 
instance_attrs}; + }(); + const auto& init_func = + ClassAttrsHelper{}.OptGet(class_attrs, + "__init__"); + if (init_func.has_value()) { + Method f{instance, init_func.value()}; + ADT_RETURN_IF_ERR(interpreter->InterpretCall(f, args)); + } else { + ADT_CHECK(args.size() == 0) << adt::errors::TypeError{ + std::string() + self.class_attrs->class_name + + "() takes no arguments"}; + } + return instance; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/class_ops.h b/paddle/ap/include/axpr/class_ops.h new file mode 100644 index 00000000000000..6efac2208e4680 --- /dev/null +++ b/paddle/ap/include/axpr/class_ops.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/class_attrs.h" + +namespace ap::axpr { + +template +class ClassOps { + public: + virtual ~ClassOps() = default; + + virtual const ClassAttrsImpl* class_attrs() const; + virtual adt::Result Equals(const ValueT& lhs_val, + const ValueT& rhs_val) const; + + protected: + ClassOps() = default; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/closure.h b/paddle/ap/include/axpr/closure.h new file mode 100644 index 00000000000000..21a35a91bafdb7 --- /dev/null +++ b/paddle/ap/include/axpr/closure.h @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +class Environment; + +template +struct ClosureImpl { + Lambda lambda; + std::shared_ptr> environment; + + bool operator==(const ClosureImpl& other) const { + return other.lambda == this->lambda && + other.environment == this->environment; + } +}; + +template +ADT_DEFINE_RC(Closure, const ClosureImpl); + +template +struct TypeImpl> : public std::monostate { + using value_type = Closure; + + const char* Name() const { return "closure"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/closure_method_class.h b/paddle/ap/include/axpr/closure_method_class.h new file mode 100644 index 00000000000000..40b164869d7db5 --- /dev/null +++ b/paddle/ap/include/axpr/closure_method_class.h @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/closure.h" +#include "paddle/ap/include/axpr/const_global_environment.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/mutable_global_environment.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +template +struct ClosureMethodClass { + using This = ClosureMethodClass; + using Self = Closure; + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, TryGetImpl(attr_name_val)); + if (attr_name == "__function__") { + return ToFunction(self); + } + return adt::errors::AttributeError{std::string() + + "closure object has not attribute '" + + attr_name + "'."}; + } + + adt::Result ToFunction(const Self& self) { + const auto& global_frame = + self->environment->RecursivelyGetConstGlobalFrame(); + return Function{self->lambda, global_frame}; + } +}; + +template +struct MethodClassImpl> + : public ClosureMethodClass {}; + +template +struct MethodClassImpl>> {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/const_global_environment.h b/paddle/ap/include/axpr/const_global_environment.h new file mode 100644 index 00000000000000..3b26c0e9717f23 --- /dev/null +++ b/paddle/ap/include/axpr/const_global_environment.h @@ -0,0 +1,67 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/environment.h" +#include "paddle/ap/include/axpr/frame.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +template +class ConstGlobalEnvironment : public Environment { + public: + ConstGlobalEnvironment(const std::shared_ptr>& parent, + const Frame& frame) + : parent_(parent), frame_(frame) {} + + adt::Result Get(const std::string& var) const override { + ADT_LET_CONST_REF(frame_ptr, frame_.Get()); + const auto& res = frame_ptr->OptGet(var); + if (res.has_value()) { + return res.value().template CastTo(); + } + if (parent_ == nullptr) { + return NameError{std::string("name '") + var + "' is not defined."}; + } + return parent_->Get(var); + } + + adt::Result Set(const std::string& var, const ValueT& val) override { + return adt::errors::RuntimeError{"const global environment is immutable."}; + } + + std::optional> GetConstGlobalFrame() const override { + return frame_; + } + + std::optional> RecursivelyGetConstGlobalFrame() + const override { + return frame_; + } + + const Frame& frame() const { return frame_; } + + private: + ConstGlobalEnvironment(const ConstGlobalEnvironment&) = delete; + ConstGlobalEnvironment(ConstGlobalEnvironment&&) = delete; + + std::shared_ptr> parent_; + Frame frame_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/constants.h 
b/paddle/ap/include/axpr/constants.h new file mode 100644 index 00000000000000..47ede616dc58f0 --- /dev/null +++ b/paddle/ap/include/axpr/constants.h @@ -0,0 +1,17 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/builtin_symbol.h" diff --git a/paddle/ap/include/axpr/continuation.h b/paddle/ap/include/axpr/continuation.h new file mode 100644 index 00000000000000..e7f123bb9bb88a --- /dev/null +++ b/paddle/ap/include/axpr/continuation.h @@ -0,0 +1,46 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct ContinuationImpl { + Lambda lambda; + std::shared_ptr> environment; + + bool operator==(const ContinuationImpl& other) const { + return other.lambda == this->lambda && + other.environment == this->environment; + } +}; + +template +ADT_DEFINE_RC(Continuation, ContinuationImpl); + +template +struct TypeImpl> : public std::monostate { + using value_type = Continuation; + + const char* Name() const { return "__builtin_continuation_class__"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/continuation_method_class.h b/paddle/ap/include/axpr/continuation_method_class.h new file mode 100644 index 00000000000000..bd09f6830287db --- /dev/null +++ b/paddle/ap/include/axpr/continuation_method_class.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/continuation.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct ContinuationMethodClass { + using This = ContinuationMethodClass; + using Self = Continuation; +}; + +template +struct MethodClassImpl> + : public ContinuationMethodClass {}; + +template +struct MethodClassImpl>> {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/core_expr.h b/paddle/ap/include/axpr/core_expr.h new file mode 100644 index 00000000000000..e7e0d88a10c027 --- /dev/null +++ b/paddle/ap/include/axpr/core_expr.h @@ -0,0 +1,177 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/constants.h" + +namespace ap::axpr { + +using SymbolImpl = std::variant, builtin_symbol::Symbol>; + +struct Symbol : public SymbolImpl { + using SymbolImpl::SymbolImpl; + ADT_DEFINE_VARIANT_METHODS(SymbolImpl); + + std::size_t GetHashValue() const { + std::size_t hash_value = Match( + [&](const tVar& var) { + return std::hash()(var.value()); + }, + [&](const builtin_symbol::Symbol& symbol) { + return symbol.GetHashValue(); + }); + return adt::hash_combine(hash_value, this->index()); + } + + std::string Name() const { + return Match( + [](const tVar& var) -> std::string { return var.value(); }, + [](const builtin_symbol::Symbol& symbol) -> std::string { + return symbol.Name(); + }); + } +}; + +struct CoreExpr; + +template <> +struct ExprSymbolTrait { + using symbol_type = Symbol; +}; + +// (outer_func (inner_func [args])) +template +struct ComposedCallImpl { + T outer_func; + T inner_func; + std::vector args; + + bool operator==(const ComposedCallImpl& other) const { + return (this->outer_func == other.outer_func) && + (this->inner_func == other.inner_func) && (this->args == other.args); + } +}; + +template +ADT_DEFINE_RC(ComposedCall, const ComposedCallImpl); + +template +using ComposedCallAtomic = ComposedCall>; + +// core expr +// expr := aexpr | (aexpr (aexpr [aexpr])) +using CoreExprBase = + std::variant, ComposedCallAtomic>; + +struct CoreExpr : public CoreExprBase { + using CoreExprBase::CoreExprBase; + ADT_DEFINE_VARIANT_METHODS(CoreExprBase); + + std::string ToSExpression() const; +}; + +size_t GetHashValue(const CoreExpr& core_expr); +size_t GetHashValue(const ComposedCallAtomic& composed_call); +size_t GetHashValue(const Atomic& atomic); +size_t GetHashValue(const Lambda& lambda); + +inline size_t GetHashValue(const CoreExpr& core_expr) { + size_t hash_value = + core_expr.Match([&](const auto& impl) { return 
GetHashValue(impl); }); + return adt::hash_combine(hash_value, core_expr.index()); +} + +inline size_t GetHashValue(const ComposedCallAtomic& composed_call) { + size_t ret = 0; + ret = adt::hash_combine(ret, GetHashValue(composed_call->outer_func)); + ret = adt::hash_combine(ret, GetHashValue(composed_call->inner_func)); + for (const auto& arg : composed_call->args) { + ret = adt::hash_combine(ret, GetHashValue(arg)); + } + return ret; +} + +inline size_t GetHashValue(const Atomic& atomic) { + size_t ret = atomic.Match( + [](const adt::Nothing) -> size_t { return 0; }, + [](const Symbol& symbol) -> size_t { return symbol.GetHashValue(); }, + [](const bool val) -> size_t { return val; }, + [](const int64_t val) -> size_t { return val; }, + [](const double val) -> size_t { + return *reinterpret_cast(&val); + }, + [](const std::string& val) -> size_t { + return std::hash()(val); + }, + [](const Lambda& lambda) -> size_t { + return GetHashValue(lambda); + }); + return adt::hash_combine(ret, atomic.index()); +} + +inline size_t GetHashValue(const Lambda& lambda) { + size_t ret = 0; + for (const auto& arg : lambda->args) { + ret = adt::hash_combine(ret, std::hash()(arg.value())); + } + return adt::hash_combine(ret, GetHashValue(lambda->body)); +} + +} // namespace ap::axpr + +namespace std { + +inline std::ostream& operator<<(std::ostream& os, + const ap::axpr::CoreExpr& core_expr) { + return os << core_expr.ToSExpression(); +} + +template <> +struct hash { + size_t operator()(const ap::axpr::CoreExpr& core_expr) const { + return GetHashValue(core_expr); + } +}; + +template <> +struct hash> { + size_t operator()( + const ap::axpr::Lambda& core_expr) const { + return GetHashValue(core_expr); + } +}; + +template <> +struct hash> { + size_t operator()( + const ap::axpr::Atomic& core_expr) const { + return GetHashValue(core_expr); + } +}; + +template <> +struct hash> { + size_t operator()( + const ap::axpr::ComposedCallAtomic& core_expr) const { + return 
GetHashValue(core_expr); + } +}; + +} // namespace std diff --git a/paddle/ap/include/axpr/core_expr_builder.h b/paddle/ap/include/axpr/core_expr_builder.h new file mode 100644 index 00000000000000..d13996c8dd9c6b --- /dev/null +++ b/paddle/ap/include/axpr/core_expr_builder.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/atomic_builder.h" +#include "paddle/ap/include/axpr/core_expr.h" + +namespace ap::axpr { + +class CoreExprBuilder : public AtomicExprBuilder { + public: + CoreExprBuilder() {} + CoreExprBuilder(const CoreExprBuilder&) = delete; + CoreExprBuilder(CoreExprBuilder&&) = delete; + + ap::axpr::ComposedCallAtomic ComposedCallAtomic( + const Atomic& outer_func, + const Atomic& inner_func, + const std::vector>& args) { + return ap::axpr::ComposedCallAtomic{outer_func, inner_func, args}; + } + + private: +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/cps_interpreter.h b/paddle/ap/include/axpr/cps_interpreter.h new file mode 100644 index 00000000000000..a458a766471cb3 --- /dev/null +++ b/paddle/ap/include/axpr/cps_interpreter.h @@ -0,0 +1,649 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/bool_helper.h" +#include "paddle/ap/include/axpr/builtin_classes.h" +#include "paddle/ap/include/axpr/builtin_environment.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/axpr/builtin_functions.h" +#include "paddle/ap/include/axpr/call_environment.h" +#include "paddle/ap/include/axpr/const_global_environment.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/interpreter_base.h" +#include "paddle/ap/include/axpr/module_mgr_helper.h" +#include "paddle/ap/include/axpr/mutable_global_environment.h" +#include "paddle/ap/include/axpr/to_string.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/axpr/value_method_class.h" + +namespace ap::axpr { + +class CpsInterpreter : public InterpreterBase { + public: + using This = CpsInterpreter; + using Env = Environment; + explicit CpsInterpreter( + const AttrMap& builtin_frame_attr_map, + const std::weak_ptr& circlable_ref_list) + : builtin_env_(GetBuiltinEnvironment(builtin_frame_attr_map)), + circlable_ref_list_(circlable_ref_list) {} + CpsInterpreter(const CpsInterpreter&) = delete; + CpsInterpreter(CpsInterpreter&&) = delete; + + using Ok = adt::Result; + + const std::shared_ptr& builtin_env() const { return builtin_env_; } + + Result Interpret(const Lambda& lambda, + const std::vector& args) { + Function function{lambda, std::nullopt}; + return 
Interpret(function, args); + } + + Result Interpret(const axpr::Value& function, + const std::vector& args) { + return InterpretCall(function, args); + } + + Result InterpretCall( + const axpr::Value& func, const std::vector& args) override { + ComposedCallImpl composed_call{&BuiltinHalt, func, args}; + ADT_RETURN_IF_ERR(InterpretComposedCallUntilHalt(&composed_call)); + ADT_CHECK(IsHalt(composed_call.inner_func)) + << RuntimeError{"CpsInterpreter does not halt."}; + ADT_CHECK(composed_call.args.size() == 1) << RuntimeError{ + std::string() + "halt function takes 1 argument. but " + + std::to_string(composed_call.args.size()) + " were given."}; + return composed_call.args.at(0); + } + + Result InterpretModule( + const Frame& const_global_frame, + const Lambda& lambda) override { + std::optional>> env; + { + ADT_LET_CONST_REF(ref_lst, adt::WeakPtrLock(circlable_ref_list_)); + auto tmp_frame_object = std::make_shared>(); + auto tmp_frame = Frame::Make(ref_lst, tmp_frame_object); + const auto& mut_global_env = MakeMutableGlobalEnvironment( + builtin_env(), const_global_frame, tmp_frame); + env = mut_global_env; + } + ADT_CHECK(lambda->args.empty()); + ADT_RETURN_IF_ERR(env.value()->Set(kBuiltinReturn(), &BuiltinHalt)); + Continuation continuation{lambda, env.value()}; + const auto& ret = InterpretCall(continuation, {}); + return ret; + } + + protected: + Ok InterpretComposedCallUntilHalt( + ComposedCallImpl* composed_call) { + while (!IsHalt(composed_call->inner_func)) { + ADT_RETURN_IF_ERR(InterpretComposedCall(composed_call)); + } + return adt::Ok{}; + } + + Ok InterpretComposedCall(ComposedCallImpl* composed_call) { + using TypeT = typename TypeTrait::TypeT; + return composed_call->inner_func.Match( + [&](const TypeT& type) -> Ok { + return InterpretConstruct(type, composed_call); + }, + [&](const BuiltinFuncType& func) -> Ok { + return InterpretBuiltinFuncCall(func, composed_call); + }, + [&](const BuiltinHighOrderFuncType& func) -> Ok { + return 
InterpretBuiltinHighOrderFuncCall(func, composed_call); + }, + [&](const Method& method) -> Ok { + return method->func.Match( + [&](const BuiltinFuncType& func) { + return InterpretBuiltinMethodCall( + func, method->obj, composed_call); + }, + [&](const BuiltinHighOrderFuncType& func) { + return InterpretBuiltinHighOrderMethodCall( + func, method->obj, composed_call); + }, + [&](const auto&) { + return InterpretMethodCall(method, composed_call); + }); + }, + [&](const Closure& closure) -> Ok { + return InterpretClosureCall(composed_call->outer_func, + closure, + composed_call->args, + composed_call); + }, + [&](const Continuation& continuation) -> Ok { + return InterpretContinuation( + &BuiltinHalt, continuation, composed_call); + }, + [&](const Function& function) -> Ok { + ADT_LET_CONST_REF(closure, ConvertFunctionToClosure(function)); + return InterpretClosureCall(composed_call->outer_func, + closure, + composed_call->args, + composed_call); + }, + [&](const builtin_symbol::Symbol& symbol) -> Ok { + return InterpretBuiltinSymbolCall(symbol, composed_call); + }, + [&](const auto&) -> Ok { + const auto& call_func = + MethodClass::template GetBuiltinUnaryFunc< + builtin_symbol::Call>(composed_call->inner_func); + ADT_RETURN_IF_ERR(call_func.Match( + [&](const adt::Nothing&) -> Ok { + return adt::errors::TypeError{ + std::string("'") + + axpr::GetTypeName(composed_call->inner_func) + + "' object is not callable"}; + }, + [&](adt::Result (*unary_func)( + const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(func, unary_func(composed_call->inner_func)); + composed_call->inner_func = func; + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)( + InterpreterBase*, const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(func, + unary_func(this, composed_call->inner_func)); + composed_call->inner_func = func; + return adt::Ok{}; + })); + return adt::Ok{}; + }); + } + + bool IsHalt(const axpr::Value& func) { + return func.Match( + [&](BuiltinFuncType f) { return f == &BuiltinHalt; 
}, + [&](const auto&) { return false; }); + } + + Result InterpretAtomic(const std::shared_ptr& env, + const Atomic& atomic) { + return atomic.Match( + [&](const Lambda& lambda) -> Result { + if (const auto& const_global_frame = env->GetConstGlobalFrame()) { + return Function{lambda, + const_global_frame.value()}; + } else { + return Closure{lambda, env}; + } + }, + [&](const Symbol& symbol) -> Result { + return symbol.Match( + [&](const tVar& var) -> Result { + ADT_LET_CONST_REF(val, env->Get(var.value())) + << adt::errors::NameError{std::string("var '") + + var.value() + + "' is not defined."}; + return val; + }, + [&](const builtin_symbol::Symbol& symbol) -> Result { + return symbol; + }); + }, + [&](adt::Nothing) -> Result { return adt::Nothing{}; }, + [&](bool c) -> Result { return c; }, + [&](int64_t c) -> Result { return c; }, + [&](double c) -> Result { return c; }, + [&](const std::string& val) -> Result { return val; }); + } + + Result InterpretAtomicAsContinuation( + const std::shared_ptr& env, const Atomic& atomic) { + return atomic.Match( + [&](const Lambda& lambda) -> Result { + return Continuation{lambda, env}; + }, + [&](const Symbol& symbol) -> Result { + return symbol.Match( + [&](const tVar& var) -> Result { + ADT_CHECK(var.value() == kBuiltinReturn()); + ADT_LET_CONST_REF(val, env->Get(var.value())) + << adt::errors::NotImplementedError{ + "no return continuation found."}; + return val; + }, + [&](const auto&) -> Result { + return adt::errors::NotImplementedError{ + "Invalid continuation."}; + }); + }, + [&](const auto&) -> Result { + return adt::errors::NotImplementedError{"Invalid continuation."}; + }); + } + + Ok InterpretBuiltinSymbolCall( + const builtin_symbol::Symbol& symbol, + ComposedCallImpl* ret_composed_call) { + return symbol.Match( + [&](const builtin_symbol::If&) -> Ok { + ADT_RETURN_IF_ERR(InterpretIf(ret_composed_call)); + return adt::Ok{}; + }, + [&](const builtin_symbol::Id&) -> Ok { + ret_composed_call->inner_func = 
&BuiltinIdentity; + return adt::Ok{}; + }, + [&](const builtin_symbol::List&) -> Ok { + ret_composed_call->inner_func = &BuiltinList; + return adt::Ok{}; + }, + [&](const builtin_symbol::Op& op) -> Ok { + return op.Match([&](auto impl) -> Ok { + using BuiltinSymbol = decltype(impl); + if constexpr (BuiltinSymbol::num_operands == 1) { + return this + ->template InterpretBuiltinUnarySymbolCall( + ret_composed_call); + } else if constexpr (BuiltinSymbol::num_operands == 2) { + return this + ->template InterpretBuiltinBinarySymbolCall( + ret_composed_call); + } else { + static_assert(true, "NotImplemented"); + return NotImplementedError{"NotImplemented."}; + } + }); + }); + } + + Ok InterpretIf(ComposedCallImpl* composed_call) { + const auto args = composed_call->args; + ADT_CHECK(args.size() == 3) + << TypeError{std::string("`if` takes 3 arguments, but ") + + std::to_string(args.size()) + "were given."}; + const auto& cond = args.at(0); + ADT_LET_CONST_REF(select_true_branch, BoolHelper{}.ConvertToBool(cond)); + ADT_LET_CONST_REF(true_closure, + args.at(1).template TryGet>()); + ADT_LET_CONST_REF(false_closure, + args.at(2).template TryGet>()); + Closure closure{select_true_branch ? true_closure + : false_closure}; + composed_call->inner_func = closure; + composed_call->args = std::vector{}; + return adt::Ok{}; + } + + template + Ok InterpretBuiltinUnarySymbolCall( + ComposedCallImpl* ret_composed_call) { + ADT_CHECK(ret_composed_call->args.size() == 1) << TypeError{ + std::string() + "'" + BuiltinSymbol::Name() + + "' takes 1 argument. 
but " + + std::to_string(ret_composed_call->args.size()) + " were given."}; + const auto& operand = ret_composed_call->args.at(0); + std::optional opt_ret; + const auto& func = + MethodClass::template GetBuiltinUnaryFunc( + operand); + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return TypeError{std::string() + "unsupported operand type for " + + GetBuiltinSymbolDebugString() + + ": '" + axpr::GetTypeName(operand) + "'"}; + }, + [&](adt::Result (*unary_func)(const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(ret, unary_func(operand)); + opt_ret = ret; + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)( + InterpreterBase*, const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(ret, unary_func(this, operand)); + opt_ret = ret; + return adt::Ok{}; + })); + ADT_CHECK(opt_ret.has_value()); + ret_composed_call->args = {opt_ret.value()}; + ret_composed_call->inner_func = ret_composed_call->outer_func; + ret_composed_call->outer_func = &BuiltinHalt; + return adt::Ok{}; + } + + template + Ok InterpretConstruct(const TypeT& type, + ComposedCallImpl* ret_composed_call) { + const auto& func = MethodClass::template GetBuiltinUnaryFunc< + builtin_symbol::Call>(axpr::Value{type}); + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return adt::errors::TypeError{ + std::string() + "no constructor for type '" + type.Name() + "'"}; + }, + [&](adt::Result (*unary_func)(const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(constructor, unary_func(axpr::Value{type})); + ret_composed_call->inner_func = constructor; + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)( + InterpreterBase*, const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(constructor, unary_func(this, axpr::Value{type})); + ret_composed_call->inner_func = constructor; + return adt::Ok{}; + })); + return adt::Ok{}; + } + + template + Ok InterpretBuiltinBinarySymbolCall( + ComposedCallImpl* ret_composed_call) { + ADT_CHECK(ret_composed_call->args.size() == 2) << TypeError{ + std::string() 
+ "'" + BuiltinSymbol::Name() + + "' takes 2 argument. but " + + std::to_string(ret_composed_call->args.size()) + " were given."}; + const auto& lhs = ret_composed_call->args.at(0); + const auto& func = + MethodClass::template GetBuiltinBinaryFunc( + lhs); + std::optional opt_ret; + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return TypeError{std::string() + "unsupported operand type for " + + GetBuiltinSymbolDebugString() + + ": '" + axpr::GetTypeName(lhs) + "'"}; + }, + [&](adt::Result (*binary_func)(const axpr::Value&, + const axpr::Value&)) -> Ok { + const auto& rhs = ret_composed_call->args.at(1); + ADT_LET_CONST_REF(ret, binary_func(lhs, rhs)); + opt_ret = ret; + return adt::Ok{}; + }, + [&](adt::Result (*binary_func)( + InterpreterBase*, + const axpr::Value&, + const axpr::Value&)) -> Ok { + const auto& rhs = ret_composed_call->args.at(1); + ADT_LET_CONST_REF(ret, binary_func(this, lhs, rhs)); + opt_ret = ret; + return adt::Ok{}; + })); + ADT_CHECK(opt_ret.has_value()); + ret_composed_call->args = {opt_ret.value()}; + ret_composed_call->inner_func = ret_composed_call->outer_func; + ret_composed_call->outer_func = &BuiltinHalt; + return adt::Ok{}; + } + + Ok InterpretClosureCall(const axpr::Value& continuation, + const Closure& closure, + const std::vector& args, + ComposedCallImpl* ret_composed_call) { + ADT_LET_CONST_REF(new_env, MakeCallEnvironment(closure->environment)); + ADT_RETURN_IF_ERR(new_env->Set(kBuiltinReturn(), continuation)); + return InterpretLambdaCall( + new_env, continuation, closure->lambda, args, ret_composed_call); + } + + Ok InterpretLambdaCall( + const std::shared_ptr& env, + const axpr::Value& outer_func, + const Lambda& lambda, + const std::vector& args, + ComposedCallImpl* ret_composed_call) override { + auto PassPackedArgs = [&](const std::optional& self, + const axpr::Value& packed) -> Ok { + ADT_LET_CONST_REF(packed_args, + packed.template TryGet>()); + const auto& [pos_args, kwargs] = *packed_args; + int 
lambda_arg_idx = (self.has_value() ? 1 : 0); + ADT_CHECK(lambda_arg_idx + pos_args->size() <= lambda->args.size()) + << TypeError{std::string("() takes ") + + std::to_string(lambda->args.size()) + + "at most positional arguments but " + + std::to_string(pos_args->size()) + " was given"}; + std::set passed_args; + if (self.has_value()) { + const auto& self_name = lambda->args.at(0).value(); + passed_args.insert(self_name); + ADT_RETURN_IF_ERR(env->Set(self_name, self.value())); + } + for (int pos_arg_idx = 0; pos_arg_idx < pos_args->size(); + ++pos_arg_idx, ++lambda_arg_idx) { + const auto& arg_name = lambda->args.at(lambda_arg_idx).value(); + passed_args.insert(arg_name); + ADT_RETURN_IF_ERR(env->Set(arg_name, pos_args->at(pos_arg_idx))); + } + for (; lambda_arg_idx < lambda->args.size(); ++lambda_arg_idx) { + const auto& arg_name = lambda->args.at(lambda_arg_idx).value(); + if (passed_args.count(arg_name) > 0) { + return adt::errors::TypeError{ + std::string() + "() got multiple values for argument '" + + arg_name + "'"}; + } + passed_args.insert(arg_name); + ADT_LET_CONST_REF(kwarg, kwargs->Get(arg_name)) + << adt::errors::TypeError{ + std::string() + + "() missing 1 required positional argument: '" + + arg_name + "'"}; + ADT_RETURN_IF_ERR(env->Set(arg_name, kwarg)); + } + for (const auto& [key, _] : kwargs->storage) { + ADT_CHECK(passed_args.count(key) > 0) << adt::errors::TypeError{ + std::string() + "() got an unexpected keyword argument '" + + key + "'"}; + } + return adt::Ok{}; + }; + if (args.size() == 1 && + args.at(0).template Has>()) { + ADT_RETURN_IF_ERR( + PassPackedArgs(/*self=*/std::nullopt, /*packed=*/args.at(0))); + } else if (args.size() == 2 && + args.at(1).template Has>()) { + ADT_RETURN_IF_ERR( + PassPackedArgs(/*self=*/args.at(0), /*packed=*/args.at(1))); + } else { + if (args.size() > lambda->args.size()) { + return adt::errors::TypeError{ + std::string("() takes ") + + std::to_string(lambda->args.size()) + " positional arguments but " + + 
std::to_string(args.size()) + " was given"}; + } + if (args.size() < lambda->args.size()) { + if (args.size() + 1 == lambda->args.size()) { + return adt::errors::TypeError{ + "() missing 1 required positional argument: '" + + lambda->args.at(args.size()).value() + "'"}; + } else { + std::ostringstream ss; + ss << "() missing " << (lambda->args.size() - args.size()) + << " required positional arguments: "; + ss << "'" << lambda->args.at(args.size()).value() << "'"; + for (int i = args.size() + 1; i < lambda->args.size(); ++i) { + ss << "and '" << lambda->args.at(i).value() << "'"; + } + return adt::errors::TypeError{ss.str()}; + } + } + for (int i = 0; i < args.size(); ++i) { + const auto& arg_name = lambda->args.at(i).value(); + ADT_RETURN_IF_ERR(env->Set(arg_name, args.at(i))); + } + } + return InterpretLambdaBody( + env, outer_func, lambda->body, ret_composed_call); + } + + Ok InterpretContinuation(const axpr::Value& outer_func, + const Continuation& continuation, + ComposedCallImpl* composed_call) { + const auto& env = continuation->environment; + const auto& lambda = continuation->lambda; + if (lambda->args.size() > 0) { + ADT_CHECK(lambda->args.size() == 1); + ADT_CHECK(composed_call->args.size() == 1); + ADT_RETURN_IF_ERR( + env->Set(lambda->args.at(0).value(), composed_call->args.at(0))); + } else { + // Do nothing. 
+ } + return InterpretLambdaBody(env, outer_func, lambda->body, composed_call); + } + + Ok InterpretLambdaBody(const std::shared_ptr& env, + const axpr::Value& outer_func, + const CoreExpr& lambda_body, + ComposedCallImpl* ret_composed_call) { + return lambda_body.Match( + [&](const Atomic& atomic) -> Ok { + ADT_LET_CONST_REF(val, InterpretAtomic(env, atomic)); + ret_composed_call->inner_func = outer_func; + ret_composed_call->outer_func = &BuiltinHalt; + ret_composed_call->args = {val}; + return adt::Ok{}; + }, + [&](const ComposedCallAtomic& core_expr) -> Ok { + return InterpretLambdaBodyComposedCallAtomic( + env, core_expr, ret_composed_call); + }); + } + + Ok InterpretLambdaBodyComposedCallAtomic( + const std::shared_ptr& env, + const ComposedCallAtomic& core_expr, + ComposedCallImpl* ret_composed_call) { + ADT_LET_CONST_REF( + continuation, + InterpretAtomicAsContinuation(env, core_expr->outer_func)); + ADT_LET_CONST_REF(new_inner_func, + InterpretAtomic(env, core_expr->inner_func)); + std::vector args; + args.reserve(core_expr->args.size()); + for (const auto& arg_expr : core_expr->args) { + ADT_LET_CONST_REF(arg, InterpretAtomic(env, arg_expr)); + args.emplace_back(arg); + } + ret_composed_call->outer_func = continuation; + ret_composed_call->inner_func = new_inner_func; + ret_composed_call->args = std::move(args); + return adt::Ok{}; + } + + Ok InterpretBuiltinFuncCall(const BuiltinFuncType& func, + ComposedCallImpl* composed_call) { + return InterpretBuiltinMethodCall( + func, axpr::Value{adt::Nothing{}}, composed_call); + } + + Ok InterpretBuiltinHighOrderFuncCall( + const BuiltinHighOrderFuncType& func, + ComposedCallImpl* composed_call) { + return InterpretBuiltinHighOrderMethodCall( + func, axpr::Value{adt::Nothing{}}, composed_call); + } + + Ok InterpretBuiltinMethodCall(const BuiltinFuncType& func, + const axpr::Value& obj, + ComposedCallImpl* composed_call) { + ADT_LET_CONST_REF(inner_ret, func(obj, composed_call->args)); + composed_call->inner_func 
= composed_call->outer_func; + composed_call->outer_func = &BuiltinHalt; + composed_call->args = {inner_ret}; + return adt::Ok{}; + } + + Ok InterpretBuiltinHighOrderMethodCall( + const BuiltinHighOrderFuncType& func, + const axpr::Value& obj, + ComposedCallImpl* composed_call) { + ADT_LET_CONST_REF(inner_ret, func(this, obj, composed_call->args)); + composed_call->inner_func = composed_call->outer_func; + composed_call->outer_func = &BuiltinHalt; + composed_call->args = {inner_ret}; + return adt::Ok{}; + } + + Ok InterpretMethodCall(const Method& method, + ComposedCallImpl* composed_call) { + std::vector new_args; + new_args.reserve(composed_call->args.size() + 1); + new_args.emplace_back(method->obj); + for (const auto& arg : composed_call->args) { + new_args.emplace_back(arg); + } + composed_call->inner_func = method->func; + composed_call->args = std::move(new_args); + return adt::Ok{}; + } + + std::weak_ptr circlable_ref_list() + const override { + return circlable_ref_list_; + } + + std::shared_ptr builtin_env_; + std::weak_ptr circlable_ref_list_; + + private: + Result> ConvertFunctionToClosure( + const Function& function) { + const auto& global_frame = function->global_frame; + if (global_frame.has_value()) { + const auto& const_env = + MakeConstGlobalEnvironment(builtin_env(), global_frame.value()); + return Closure{function->lambda, const_env}; + } else { + return Closure{function->lambda, builtin_env()}; + } + } + + static std::shared_ptr> GetBuiltinEnvironment( + const AttrMap& builtin_frame_attr_map) { + return std::make_shared>( + builtin_frame_attr_map); + } + + static std::shared_ptr> MakeConstGlobalEnvironment( + const std::shared_ptr>& parent, + const Frame& frame) { + return std::make_shared>(parent, frame); + } + + static std::shared_ptr> MakeMutableGlobalEnvironment( + const std::shared_ptr>& parent, + const Frame& const_frame, + const Frame& temp_frame) { + return std::make_shared>( + parent, const_frame, temp_frame); + } + + adt::Result>> 
MakeCallEnvironment( + const std::shared_ptr>& parent) { + auto builtin_obj = std::make_shared>(); + ADT_LET_CONST_REF(ref_lst, adt::WeakPtrLock(circlable_ref_list())); + const auto& frame = Frame::Make(ref_lst, builtin_obj); + return std::make_shared>(parent, frame); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/data_type.h b/paddle/ap/include/axpr/data_type.h new file mode 100644 index 00000000000000..9f66c1f76ce418 --- /dev/null +++ b/paddle/ap/include/axpr/data_type.h @@ -0,0 +1,141 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/type.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float8_e4m3fn.h" +#include "paddle/phi/common/float8_e5m2.h" +#include "paddle/phi/common/pstring.h" + +namespace ap::axpr { + +using complex64 = ::phi::dtype::complex; +using complex128 = ::phi::dtype::complex; +using float16 = ::phi::dtype::float16; +using bfloat16 = ::phi::dtype::bfloat16; +using float8_e4m3fn = ::phi::dtype::float8_e4m3fn; +using float8_e5m2 = ::phi::dtype::float8_e5m2; +using pstring = ::phi::dtype::pstring; + +#define PEXPR_FOR_EACH_ARITHMETIC_OP_SUPPORTED_TYPE(_) \ + _(bool) \ + _(float) \ + _(double) \ + _(int8_t) \ + _(uint8_t) \ + _(int16_t) \ + _(uint16_t) \ + _(int32_t) \ + _(uint32_t) \ + _(int64_t) \ + _(uint64_t) + +namespace detail { + +template +struct IsArithmeticOpSupportedHelper { + static constexpr bool value = false; +}; + +#define SPECIALIZE_IS_ARITHMETIC_OP_SUPPORTED(cpp_type) \ + template <> \ + struct IsArithmeticOpSupportedHelper { \ + static constexpr bool value = true; \ + }; + +PEXPR_FOR_EACH_ARITHMETIC_OP_SUPPORTED_TYPE( + SPECIALIZE_IS_ARITHMETIC_OP_SUPPORTED) + +#undef SPECIALIZE_IS_ARITHMETIC_OP_SUPPORTED + +} // namespace detail + +template +constexpr bool IsArithmeticOpSupported() { + return detail::IsArithmeticOpSupportedHelper::value; +} + +template +struct GetDataTypeNameHelper; + +#define SPECIALIZE_GET_CPP_TYPE_NAME(cpp_type, enum_type) \ + template <> \ + struct GetDataTypeNameHelper { \ + static const char* Name() { return #cpp_type; } \ + }; \ + template <> \ + struct GetDataTypeNameHelper { \ + static const char* Name() { return "const_" #cpp_type; } \ + }; +PD_FOR_EACH_DATA_TYPE(SPECIALIZE_GET_CPP_TYPE_NAME); +#undef SPECIALIZE_GET_CPP_TYPE_NAME +template <> +struct GetDataTypeNameHelper { + static const char* Name() { return "void"; } +}; + +template <> +struct GetDataTypeNameHelper { + static const char* Name() { return "const_void"; } +}; + +template +struct 
CppDataType : public std::monostate { + using std::monostate::monostate; + using type = T; + const char* Name() const { return GetDataTypeNameHelper::Name(); } +}; + +// clang-format off +using DataTypeImpl = std::variant< +#define MAKE_ARITHMETIC_TYPE_ALTERNATIVE(cpp_type, enum_type) \ + CppDataType, + PD_FOR_EACH_DATA_TYPE(MAKE_ARITHMETIC_TYPE_ALTERNATIVE) +#undef MAKE_ARITHMETIC_TYPE_ALTERNATIVE + CppDataType>; +// clang-format on + +struct DataType : public DataTypeImpl { + using DataTypeImpl::DataTypeImpl; + ADT_DEFINE_VARIANT_METHODS(DataTypeImpl); + + const char* Name() const { + return Match([](const auto& impl) { return impl.Name(); }); + } + + std::size_t GetHashValue() const { return index(); } +}; + +template <> +struct TypeImpl : public std::monostate { + using value_type = DataType; + + const char* Name() const { return "DataType"; } +}; + +} // namespace ap::axpr + +namespace std { + +template <> +struct hash { + std::size_t operator()(ap::axpr::DataType dtype) const { + return dtype.GetHashValue(); + } +}; + +} // namespace std diff --git a/paddle/ap/include/axpr/data_type_method_class.h b/paddle/ap/include/axpr/data_type_method_class.h new file mode 100644 index 00000000000000..8afe0aa12685be --- /dev/null +++ b/paddle/ap/include/axpr/data_type_method_class.h @@ -0,0 +1,138 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/int_data_type.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/pointer_type_util.h" + +namespace ap::axpr { + +template +struct DataTypeMethodClass { + using This = DataTypeMethodClass; + using Self = DataType; + + adt::Result ToString(const Self& data_type) { + return std::string("DataType.") + data_type.Name(); + } + + adt::Result Hash(const Self& data_type) { + int64_t hash_value = std::hash()("DataType"); + hash_value = adt::hash_combine(hash_value, data_type.index()); + return hash_value; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template CastTo()); + if (attr_name == "const_pointer_type") { + return GetConstPointerType(self); + } + if (attr_name == "mutable_pointer_type") { + return GetMutablePointerType(self); + } + return adt::errors::AttributeError{ + std::string() + "DataType has no attribute '" + attr_name + "'"}; + } + + adt::Result GetConstPointerType(const Self& self) { + return GetConstPointerTypeFromDataType(self); + } + + adt::Result GetMutablePointerType(const Self& self) { + return GetMutablePointerTypeFromDataType(self); + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (std::is_same_v) { + return &This::EQ; + } else if constexpr (std::is_same_v) { + return &This::NE; + } else { + std::nullopt; + } + } + + static Result EQ(const ValueT& lhs_val, const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template TryGet()); + ADT_LET_CONST_REF(rhs, rhs_val.template TryGet()); + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> ValueT { + return std::is_same_v; + }}; + return std::visit(pattern_match, lhs.variant(), rhs.variant()); + } + + static Result NE(const ValueT& lhs_val, const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, 
lhs_val.template TryGet()); + ADT_LET_CONST_REF(rhs, rhs_val.template TryGet()); + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> ValueT { + return !std::is_same_v; + }}; + return std::visit(pattern_match, lhs.variant(), rhs.variant()); + } +}; + +template +struct MethodClassImpl : public DataTypeMethodClass { +}; + +template +struct TypeImplDataTypeMethodClass { + using This = TypeImplDataTypeMethodClass; + using Self = TypeImpl; + + adt::Result GetAttr(const Self&, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, TryGetImpl(attr_name_val)); + static const std::unordered_map map{ +#define MAKE_CPP_TYPE_CASE(cpp_type, enum_type) \ + {axpr::CppDataType{}.Name(), DataType{CppDataType{}}}, \ + {axpr::CppDataType{}.Name(), \ + DataType{ \ + CppDataType{}}}, // it's not a typo, DataType.const_int8 + // and DataType.int8 are treated + // identical. + + PD_FOR_EACH_DATA_TYPE(MAKE_CPP_TYPE_CASE) +#undef MAKE_CPP_TYPE_CASE + +#define MAKE_INT_CPP_TYPE_CASE(cpp_type) \ + {#cpp_type, DataType{CppDataType{}}}, \ + {"const_" #cpp_type, DataType{CppDataType{}}}, + + AP_FOR_EACH_INT_TYPE(MAKE_INT_CPP_TYPE_CASE) +#undef MAKE_INT_CPP_TYPE_CASE + {"void", DataType{CppDataType{}}}, + }; + const auto iter = map.find(attr_name); + if (iter != map.end()) { + return iter->second; + } + return adt::errors::AttributeError{ + std::string() + "class 'DataType' has no static attribute '" + + attr_name + "'."}; + } +}; + +template +struct MethodClassImpl> + : public TypeImplDataTypeMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/data_type_util.h b/paddle/ap/include/axpr/data_type_util.h new file mode 100644 index 00000000000000..ecb98d797471b2 --- /dev/null +++ b/paddle/ap/include/axpr/data_type_util.h @@ -0,0 +1,56 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/phi/common/data_type.h" + +namespace ap::axpr { + +inline Result GetDataTypeFromPhiDataType(::phi::DataType data_type) { + static const std::unordered_map<::phi::DataType, DataType> map{ + {::phi::DataType::UNDEFINED, DataType{CppDataType{}}}, +#define MAKE_PHI_DATA_TYPE_TO_ARG_TYPE_CASE(cpp_type, enum_type) \ + {::phi::enum_type, DataType{CppDataType{}}}, + PD_FOR_EACH_DATA_TYPE(MAKE_PHI_DATA_TYPE_TO_ARG_TYPE_CASE) +#undef MAKE_PHI_DATA_TYPE_TO_ARG_TYPE_CASE + }; + const auto& iter = map.find(data_type); + if (iter == map.end()) { + return adt::errors::KeyError{[&] { + std::ostringstream ss{}; + ss << "Invalid phi data type. 
enum value: " << data_type; + return ss.str(); + }()}; + } + return iter->second; +} + +inline Result<::phi::DataType> GetPhiDataTypeFromDataType(DataType data_type) { + static const std::unordered_map map{ + {DataType{CppDataType{}}, ::phi::DataType::UNDEFINED}, +#define MAKE_PHI_DATA_TYPE_TO_ARG_TYPE_CASE(cpp_type, enum_type) \ + {DataType{CppDataType{}}, ::phi::enum_type}, + PD_FOR_EACH_DATA_TYPE(MAKE_PHI_DATA_TYPE_TO_ARG_TYPE_CASE) +#undef MAKE_PHI_DATA_TYPE_TO_ARG_TYPE_CASE + }; + const auto& iter = map.find(data_type); + if (iter == map.end()) { + return adt::errors::KeyError{"Invalid axpr data type."}; + } + return iter->second; +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/data_value.h b/paddle/ap/include/axpr/data_value.h new file mode 100644 index 00000000000000..edee506410dbfe --- /dev/null +++ b/paddle/ap/include/axpr/data_value.h @@ -0,0 +1,109 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +using DataValueImpl = std::variant< +#define MAKE_ARG_VALUE_ALTERNATIVE(cpp_type, enum_type) cpp_type, + PD_FOR_EACH_DATA_TYPE(MAKE_ARG_VALUE_ALTERNATIVE) adt::Undefined +#undef MAKE_ARG_VALUE_ALTERNATIVE + >; + +struct DataValue : public DataValueImpl { + using DataValueImpl::DataValueImpl; + ADT_DEFINE_VARIANT_METHODS(DataValueImpl); + + DataType GetType() const { + return Match( + [](auto impl) -> DataType { return CppDataType{}; }); + } + + Result StaticCastTo(const DataType& dst_type) const { + const auto& pattern_match = ::common::Overloaded{ + [&](auto arg_type_impl, auto cpp_value_impl) -> Result { + using DstT = typename decltype(arg_type_impl)::type; + return DataValueStaticCast(cpp_value_impl); + }}; + return std::visit(pattern_match, dst_type.variant(), this->variant()); + } + + Result ToString() const { + return Match([](const auto& impl) -> adt::Result { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return std::to_string(impl); + } else if constexpr (std::is_integral_v) { + return std::to_string(impl); + } else if constexpr (std::is_same_v) { + return std::to_string(impl); + } else if constexpr (std::is_same_v) { + return std::to_string(impl); + } else { + return adt::errors::NotImplementedError{"DataType NotImplemented."}; + } + }); + } + + Result GetHashValue() const { + using RetT = Result; + return Match([](const auto& impl) -> RetT { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return static_cast(std::hash()(impl)); + } else if constexpr (std::is_integral_v) { + return static_cast(std::hash()(impl)); + } else if constexpr (std::is_same_v) { + return static_cast(std::hash()(impl)); + } else if constexpr (std::is_same_v) { + return static_cast(std::hash()(impl)); + } else { + return adt::errors::NotImplementedError{"DataType NotImplemented."}; + 
} + }); + } + + private: + template + Result DataValueStaticCast(SrcT v) const { + if constexpr (std::is_same_v) { + return adt::errors::TypeError{ + "static_cast can not cast to 'undefined' type."}; + } else if constexpr (std::is_same_v) { + return adt::errors::TypeError{ + "static_cast can not cast to 'pstring' type."}; + } else if constexpr (std::is_same_v) { + return adt::errors::TypeError{ + "static_cast can not cast from 'undefined' type."}; + } else if constexpr (std::is_same_v) { + return adt::errors::TypeError{ + "static_cast can not cast from 'pstring' type."}; + } else { + return static_cast(v); + } + } +}; + +template <> +struct TypeImpl : public std::monostate { + using value_type = DataValue; + + const char* Name() const { return "DataValue"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/data_value_method_class.h b/paddle/ap/include/axpr/data_value_method_class.h new file mode 100644 index 00000000000000..95aae9771526d6 --- /dev/null +++ b/paddle/ap/include/axpr/data_value_method_class.h @@ -0,0 +1,379 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/data_value.h" +#include "paddle/ap/include/axpr/data_value_util.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +namespace detail { + +template +Result ArgValueStaticCast(const Val& self, const std::vector& args) { + if (args.size() != 1) { + return TypeError{std::string() + "'DataValue.cast' take 1 arguments. but " + + std::to_string(args.size()) + " were given."}; + } + const Result& arg_value = self.template TryGet(); + ADT_RETURN_IF_ERR(arg_value); + const Result& arg_type = args.at(0).template TryGet(); + ADT_RETURN_IF_ERR(arg_type); + const auto& data_value = + arg_value.GetOkValue().StaticCastTo(arg_type.GetOkValue()); + ADT_RETURN_IF_ERR(data_value); + return data_value.GetOkValue(); +} + +template +adt::Result DataValueGetAttr(const DataValue& data_val, + const std::string& attr_name) { + if (attr_name == "cast") { + return ap::axpr::Method{data_val, &ArgValueStaticCast}; + } + return adt::errors::AttributeError{"'DataValue' object has no attribute '" + + attr_name + "'"}; +} + +} // namespace detail + +template +struct DataValueMethodClass { + using This = DataValueMethodClass; + using Self = DataValue; + + adt::Result ToString(const Self& self) { + ADT_LET_CONST_REF(str, self.ToString()); + return str; + } + + adt::Result Hash(const Self& self) { + ADT_LET_CONST_REF(hash_value, self.GetHashValue()); + return hash_value; + } + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::arithmetic_op_type; + return &This::UnaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr 
(ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::arithmetic_op_type; + return &This::template BinaryFunc; + } else if constexpr (std::is_same_v) { + return &This::GetAttr; + } else { + return adt::Nothing{}; + } + } + + static adt::Result GetAttr(const ValueT& obj_val, + const ValueT& attr_name_val) { + const auto& opt_obj = obj_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_obj); + const auto& obj = opt_obj.GetOkValue(); + const auto& opt_attr_name = attr_name_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_attr_name); + const auto& attr_name = opt_attr_name.GetOkValue(); + return detail::DataValueGetAttr(obj, attr_name); + } + + template + static adt::Result BinaryFunc(const ValueT& lhs_val, + const ValueT& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + const auto& lhs = opt_lhs.GetOkValue(); + const auto& opt_rhs = rhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_rhs); + const auto& rhs = opt_rhs.GetOkValue(); + const auto& ret = ArithmeticBinaryFunc(lhs, rhs); + ADT_RETURN_IF_ERR(ret); + return ret.GetOkValue(); + } + + template + static adt::Result UnaryFunc(const ValueT& val) { + const auto& opt_operand = val.template TryGet(); + ADT_RETURN_IF_ERR(opt_operand); + const auto& operand = opt_operand.GetOkValue(); + const auto& ret = ArithmeticUnaryFunc(operand); + ADT_RETURN_IF_ERR(ret); + return ret.GetOkValue(); + } +}; + +template +struct MethodClassImpl + : public DataValueMethodClass {}; + +namespace detail { + +template +adt::Result ConstructDataValue(const ValueT&, + const std::vector& args) { + if (args.size() != 1) { + return adt::errors::TypeError{ + std::string() + "constructor of 'DataValue' takes 1 arguments, but " + + std::to_string(args.size()) + " were given."}; + } + return args.at(0).Match( + [](bool c) -> adt::Result { return DataValue{c}; }, + [](int64_t c) -> adt::Result { return 
DataValue{c}; }, + [](const DataValue& c) -> adt::Result { return c; }, + [&](const auto& impl) -> adt::Result { + using T = std::decay_t; + return adt::errors::TypeError{ + std::string() + + "unsupported operand type for constructor of 'DataValue': '" + + axpr::GetTypeName(args.at(0)) + "'"}; + }); +} + +} // namespace detail + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + using Self = TypeImpl; + + adt::Result Call(const Self& self_val) { + return &detail::ConstructDataValue; + } + + adt::Result GetAttr(const Self&, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template CastTo()); + static const std::map> map{ + {"float32", &This::MakeFloat32}, + {"float64", &This::MakeFloat64}, + // {"float16", &This::Make}, + // {"bfloat16", &This::Make}, + {"int64", &This::MakeInt64}, + {"int64_t", &This::MakeInt64}, + {"int32", &This::MakeInt32}, + {"int32_t", &This::MakeInt32}, + {"int16", &This::MakeInt16}, + {"int16_t", &This::MakeInt16}, + {"int8", &This::MakeInt8}, + {"int8_t", &This::MakeInt8}, + {"uint64", &This::MakeUint64}, + {"uint64_t", &This::MakeUint64}, + {"uint32", &This::MakeUint32}, + {"uint32_t", &This::MakeUint32}, + {"uint16", &This::MakeUint16}, + {"uint16_t", &This::MakeUint16}, + {"uint8", &This::MakeUint8}, + {"uint8_t", &This::MakeUint8}, + {"bool", &This::MakeBool}, + {"complex64", &This::MakeComplex64}, + {"complex128", &This::MakeComplex128}}; + const auto& iter = map.find(attr_name); + if (iter != map.end()) { + return ValueT{iter->second}; + } + return adt::errors::NotImplementedError{std::string() + "DataValue." 
+ + attr_name + "() not implemented"}; + } + + template + static T StrToNum(const std::string& str) { + T x; + std::stringstream ss; + ss << str; + ss >> x; + return x; + } + + static adt::Result MakeFloat32(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.float32() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeFloat64(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.float64() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeInt64(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.int64() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeInt32(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.int32() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeInt16(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.int16() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeInt8(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << 
adt::errors::TypeError{ + std::string() + "DataValue.int8() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeUint64(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.uint64() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeUint32(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.uint32() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeUint16(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.uint16() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeUint8(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.uint8() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeBool(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "DataValue.bool() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(str, args.at(0).template CastTo()); + return DataValue{StrToNum(str)}; + } + + static adt::Result MakeComplex64(const ValueT&, + const 
std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "DataValue.complex64() takes 2 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(real_val, args.at(0).template CastTo()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of DataValue.complex64() should be a DataValue, " + "but a " + + axpr::GetTypeName(args.at(0)) + " were given"}; + ADT_LET_CONST_REF(real, real_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of DataValue.complex64() should be a float32, " + "but a " + + real_val.GetType().Name() + " were given"}; + ADT_LET_CONST_REF(imag_val, args.at(1).template CastTo()) + << adt::errors::TypeError{ + std::string() + + "the argument 2 of DataValue.complex64() should be a DataValue, " + "but a " + + axpr::GetTypeName(args.at(1)) + " were given"}; + ADT_LET_CONST_REF(imag, imag_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 2 of DataValue.complex64() should be a float32, " + "but a " + + imag_val.GetType().Name() + " were given"}; + return DataValue{axpr::complex64(real, imag)}; + } + + static adt::Result MakeComplex128(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "DataValue.complex128() takes 2 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(real_val, args.at(0).template CastTo()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of DataValue.complex128() should be a " + "DataValue, but a " + + axpr::GetTypeName(args.at(0)) + " were given"}; + ADT_LET_CONST_REF(real, real_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of DataValue.complex128() should be a float64, " + "but a " + + real_val.GetType().Name() + " were given"}; + ADT_LET_CONST_REF(imag_val, args.at(1).template CastTo()) + << adt::errors::TypeError{ + 
std::string() + + "the argument 2 of DataValue.complex128() should be a " + "DataValue, but a " + + axpr::GetTypeName(args.at(1)) + " were given"}; + ADT_LET_CONST_REF(imag, imag_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 2 of DataValue.complex128() should be a float64, " + "but a " + + imag_val.GetType().Name() + " were given"}; + return DataValue{axpr::complex128(real, imag)}; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/data_value_util.h b/paddle/ap/include/axpr/data_value_util.h new file mode 100644 index 00000000000000..b2de3fb1e5acfe --- /dev/null +++ b/paddle/ap/include/axpr/data_value_util.h @@ -0,0 +1,115 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "paddle/ap/include/axpr/binary_func.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/data_value.h" +#include "paddle/ap/include/axpr/unary_func.h" + +namespace ap::axpr { + +namespace detail { + +template +struct ArithmeticUnaryFuncHelper { + static Result Call(const DataValue& value) { + return value.Match([](auto val) -> Result { + if constexpr (IsArithmeticOpSupported()) { + return ArithmeticOp::Call(val); + } else { + return adt::errors::TypeError{ + std::string() + "unsupported operand type for " + + ArithmeticOp::Name() + ": " + CppDataType{}.Name() + + "."}; + } + }); + } +}; + +template +struct ArithmeticBinaryOpHelper { + template + static Result Call(LhsT lhs, RhsT rhs) { + return ArithmeticOp::Call(lhs, rhs); + } +}; + +template <> +struct ArithmeticBinaryOpHelper { + template + static Result Call(LhsT lhs, RhsT rhs) { + if (rhs == 0) { + return adt::errors::ZeroDivisionError{"division or modulo by zero"}; + } + return ArithmeticDiv::Call(lhs, rhs); + } +}; + +template <> +struct ArithmeticBinaryOpHelper { + template + static Result Call(LhsT lhs, RhsT rhs) { + if constexpr (std::is_integral_v && std::is_integral_v) { + if (rhs == 0) { + return adt::errors::ZeroDivisionError{"division or modulo by zero"}; + } + return ArithmeticMod::Call(lhs, rhs); + } else if constexpr (!std::is_integral_v) { + return adt::errors::TypeError{ + "'%' only support integral type. 'lhs' is not an integral type"}; + } else { + return adt::errors::TypeError{ + "'%' only support integral type. 
'rhs' is not an integral type"}; + } + } +}; + +template +struct ArithmeticBinaryFuncHelper { + static Result Call(const DataValue& lhs_value, + const DataValue& rhs_value) { + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> Result { + if constexpr (IsArithmeticOpSupported() && + IsArithmeticOpSupported()) { + return ArithmeticBinaryOpHelper::Call(lhs, rhs); + } else { + return adt::errors::TypeError{ + std::string() + "unsupported operand types for " + + ArithmeticOp::Name() + ": '" + + CppDataType{}.Name() + "' and '" + + CppDataType{}.Name() + "'."}; + } + }}; + return std::visit(pattern_match, lhs_value.variant(), rhs_value.variant()); + } +}; + +} // namespace detail + +template +Result ArithmeticUnaryFunc(const DataValue& value) { + return detail::ArithmeticUnaryFuncHelper::Call(value); +} + +template +Result ArithmeticBinaryFunc(const DataValue& lhs_value, + const DataValue& rhs_value) { + return detail::ArithmeticBinaryFuncHelper::Call(lhs_value, + rhs_value); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/dim_expr.h b/paddle/ap/include/axpr/dim_expr.h new file mode 100644 index 00000000000000..7ccd9e12250137 --- /dev/null +++ b/paddle/ap/include/axpr/dim_expr.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::axpr { + +template +axpr::TypeImpl> GetDimExprClass(); + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/dim_expr_method_class.h b/paddle/ap/include/axpr/dim_expr_method_class.h new file mode 100644 index 00000000000000..d4f39eceff7721 --- /dev/null +++ b/paddle/ap/include/axpr/dim_expr_method_class.h @@ -0,0 +1,102 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/axpr/interpreter_base.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/packed_args.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::axpr { + +template +struct DimExprMethodClass { + using This = DimExprMethodClass; + using Self = symbol::DimExpr; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return symbol::ToString(self); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector&) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + int64_t hash_value = std::hash()(self); + return hash_value; + } + + static adt::Result Match(axpr::InterpreterBase* interpreter, + const ValueT& self_val, + const std::vector& packed_args_val) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const auto& packed_args = axpr::CastToPackedArgs(packed_args_val); + const auto& [args, kwargs] = *packed_args; + ADT_CHECK(args->size() == 0) << adt::errors::TypeError{ + std::string() + + "DimExpr.match() supports keyword arguments only, but " + + std::to_string(args->size()) + " positional arguments were given"}; + const std::string& type_name = This{}.GetTypeName(self); + std::string key = type_name; + if (!kwargs->Has(type_name)) { + if (!kwargs->Has("_")) { + return adt::errors::TypeError{std::string() + + "DimExpr.match() failed. 
no keyword '" + + type_name + "' or '_' provided"}; + } + key = "_"; + } + ADT_LET_CONST_REF(func, kwargs->Get(key)); + ADT_CHECK(axpr::CallableHelper{}.IsCallable(func)) + << adt::errors::TypeError{ + std::string() + + "the arguments of DimExpr.match() should be callable"}; + if (key == "_") { + return interpreter->InterpretCall(func, {}); + } else { + const auto& make_args = self.Match( + [&](int64_t c) -> adt::List { return adt::List{c}; }, + [&](const std::string& c) -> adt::List { + return adt::List{c}; + }, + [&](const auto&) -> adt::List { return adt::List{}; }); + return interpreter->InterpretCall(func, make_args.vector()); + } + } + + const char* GetTypeName(const symbol::DimExpr& dim_expr) const { + return dim_expr.Match( + [](int64_t) -> const char* { return "int64"; }, + [&](const std::string&) -> const char* { return "symbol"; }, + [&](const auto&) -> const char* { return "_"; }); + } +}; + +template +axpr::TypeImpl> GetDimExprClass() { + using Impl = DimExprMethodClass; + static auto cls( + axpr::MakeBuiltinClass("DimExpr", [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("match", &Impl::Match); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/environment.h b/paddle/ap/include/axpr/environment.h new file mode 100644 index 00000000000000..0f0fb5bd994f54 --- /dev/null +++ b/paddle/ap/include/axpr/environment.h @@ -0,0 +1,45 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/frame.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +template +class Environment { + public: + virtual adt::Result Get(const std::string& var) const = 0; + + virtual adt::Result Set(const std::string& var, + const ValueT& val) = 0; + + virtual std::optional> GetConstGlobalFrame() const { + return std::nullopt; + } + + virtual std::optional> + RecursivelyGetConstGlobalFrame() const = 0; + + protected: + Environment() = default; + Environment(const Environment&) = delete; + Environment(Environment&&) = delete; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/error.h b/paddle/ap/include/axpr/error.h new file mode 100644 index 00000000000000..78c09da4d50172 --- /dev/null +++ b/paddle/ap/include/axpr/error.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" + +namespace ap::axpr { + +using adt::errors::AttributeError; +using adt::errors::Error; +using adt::errors::IndexError; +using adt::errors::InvalidArgumentError; +using adt::errors::NameError; +using adt::errors::NotImplementedError; +using adt::errors::RuntimeError; +using adt::errors::SyntaxError; +using adt::errors::TypeError; +using adt::errors::ValueError; +using adt::errors::ZeroDivisionError; + +template +using Result = adt::Result; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/exception_method_class.h b/paddle/ap/include/axpr/exception_method_class.h new file mode 100644 index 00000000000000..51232fdfc22936 --- /dev/null +++ b/paddle/ap/include/axpr/exception_method_class.h @@ -0,0 +1,61 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +ADT_DEFINE_TAG(tWrapErrorAsValue); + +using Exception = tWrapErrorAsValue; + +axpr::TypeImpl> GetExceptionClass(); + +template +adt::Result ConstructException( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(msg, args.at(0).template CastTo()); + adt::errors::Error error{ExceptionImpl{msg}}; + return GetExceptionClass().New(Exception{error}); +} + +template +void YieldExceptionConstructor(const YieldT& Yield) { + Yield(ExceptionImpl{}.class_name(), &ConstructException); +} + +template +void ForEachExceptionConstructor(const YieldT& Yield) { + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); + YieldExceptionConstructor(Yield); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/float.h b/paddle/ap/include/axpr/float.h new file mode 100644 index 00000000000000..41e874a583b5a3 --- /dev/null +++ b/paddle/ap/include/axpr/float.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using value_type = double; + + const char* Name() const { return "float"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/float_method_class.h b/paddle/ap/include/axpr/float_method_class.h new file mode 100644 index 00000000000000..fddb0045c271e4 --- /dev/null +++ b/paddle/ap/include/axpr/float_method_class.h @@ -0,0 +1,133 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/bool_int_double_arithmetic_util.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct FloatMethodClass { + using This = FloatMethodClass; + using Self = double; + + adt::Result ToString(const Self val) { return std::to_string(val); } + + adt::Result Hash(const Self val) { + return static_cast(std::hash()(val)); + } + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::arithmetic_op_type; + return &This::UnaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::arithmetic_op_type; + return &This::template BinaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static adt::Result BinaryFunc(const ValueT& lhs_val, + const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template TryGet()); + return rhs_val.Match( + [&](bool rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](int64_t rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](double rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](const auto& impl) -> adt::Result { + using T = std::decay_t; + return adt::errors::TypeError{std::string() + + "unsupported operand type(s) for " + + ArithmeticOp::Name() + ": 'float' and '" + + axpr::GetTypeName(rhs_val) + "'"}; + }); + } + + template + static adt::Result UnaryFunc(const ValueT& val) { + ADT_LET_CONST_REF(operand, val.template 
TryGet()); + return BoolIntDoubleArithmeticUnaryFunc(operand); + } +}; + +template +struct MethodClassImpl : public FloatMethodClass {}; + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + + adt::Result Call(const TypeImpl&) { return &This::Construct; } + + static adt::Result Construct(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "float() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + using T = double; + using RetT = adt::Result; + return args.at(0).Match( + [&](bool c) -> RetT { return static_cast(c); }, + [&](int64_t c) -> RetT { return static_cast(c); }, + [&](double c) -> RetT { return static_cast(c); }, + [&](DataValue data_value) -> RetT { + return data_value.Match( + [&](const axpr::pstring&) -> RetT { + return adt::errors::TypeError{ + "invalid conversion from type 'pstring' to 'float'"}; + }, + [&](const adt::Undefined&) -> RetT { + return adt::errors::TypeError{ + "invalid conversion from type 'void' to 'float'"}; + }, + [&](const auto& impl) -> RetT { return static_cast(impl); }); + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "the argument 1 of float() should be bool/int/float/DataValue"}; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/frame.h b/paddle/ap/include/axpr/frame.h new file mode 100644 index 00000000000000..a6a5945dbd7d19 --- /dev/null +++ b/paddle/ap/include/axpr/frame.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/memory/circlable_ref.h" + +namespace ap::axpr { + +template +struct Frame : public memory::CirclableRef, AttrMapImpl> { + using memory::CirclableRef, AttrMapImpl>::CirclableRef; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/function.h b/paddle/ap/include/axpr/function.h new file mode 100644 index 00000000000000..22405e2f0a9df4 --- /dev/null +++ b/paddle/ap/include/axpr/function.h @@ -0,0 +1,51 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/frame.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct FunctionImpl { + Lambda lambda; + std::optional> global_frame; + + bool operator==(const FunctionImpl& other) const { return this == &other; } + + adt::Result GetHashValue() const { + int64_t hash_value = reinterpret_cast(lambda.shared_ptr().get()); + if (global_frame.has_value()) { + ADT_LET_CONST_REF(global_frame_ptr, global_frame.value().Get()); + int64_t frame_hash_value = reinterpret_cast(global_frame_ptr); + hash_value = adt::hash_combine(hash_value, frame_hash_value); + } + return hash_value; + } +}; + +template +ADT_DEFINE_RC(Function, FunctionImpl); + +template +struct TypeImpl> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "function"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/function_method_class.h b/paddle/ap/include/axpr/function_method_class.h new file mode 100644 index 00000000000000..4938e9d3156289 --- /dev/null +++ b/paddle/ap/include/axpr/function_method_class.h @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/anf_expr_helper.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> + : public EmptyMethodClass { + using Self = Function; + + adt::Result ToString(const Self& function) { + const auto& lambda = function->lambda; + const auto& anf_expr = ConvertCoreExprToAnfExpr(lambda); + ADT_LET_CONST_REF(anf_atomic, anf_expr.template TryGet>()); + ADT_LET_CONST_REF(anf_lambda, + anf_atomic.template TryGet>()); + AnfExprHelper anf_expr_helper; + ADT_LET_CONST_REF(anf_expr_str, + anf_expr_helper.FunctionToString(anf_lambda)); + return anf_expr_str; + } + + adt::Result Hash(const Self& function) { + ADT_LET_CONST_REF(hash_value, function->GetHashValue()); + return hash_value; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, TryGetImpl(attr_name_val)); + if (attr_name == "__function__") { + return self; + } + return adt::errors::AttributeError{ + std::string() + "function has not attribute '" + attr_name + "'."}; + } +}; + +template +struct MethodClassImpl>> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/global_environment.h b/paddle/ap/include/axpr/global_environment.h new file mode 100644 index 00000000000000..41b9c3f397908d --- /dev/null +++ b/paddle/ap/include/axpr/global_environment.h @@ -0,0 +1,72 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/environment.h" + +namespace ap::axpr { + +template +class GlobalEnvironment : public Environment { + public: + adt::Result Get(const std::string& var) const override { + ADT_LET_CONST_REF(frame_ptr, frame_.Get()); + const auto& res = frame_ptr->OptGet(var); + if (res.has_value()) { + return res.value(); + } + if (parent_ == nullptr) { + return NameError{std::string("name '") + var + "' is not defined."}; + } + return parent_->Get(var); + } + + adt::Result Set(const std::string& var, const ValueT& val) override { + ADT_LET_CONST_REF(frame_ptr, frame_.Mut()); + { + static std::string tmp_var_prefix("__"); + if (var.substr(0, tmp_var_prefix.size()) != tmp_var_prefix) { + ADT_CHECK(SerializableValue::IsSerializable(val)) << [&] { + std::ostringstream ss; + ss << "Only serializable values are supported insert into global " + "environment. 
" ss + << "Builtin serializable types are: "; + ss << SerializableValue::SerializableTypeNames(); + ss << " (not include '" << axpr::GetTypeName(val) << "')."; + return adt::errors::ValueError{ss.str()}; + }(); + } + } + frame_ptr->Set(var, val); + return adt::Ok{}; + } + + const Frame& frame() const { return frame_; } + + GlobalEnvironment(const std::shared_ptr>& parent, + const Frame& frame) + : parent_(parent), frame_(frame) {} + + private: + GlobalEnvironment(const GlobalEnvironment&) = delete; + GlobalEnvironment(GlobalEnvironment&&) = delete; + + std::shared_ptr> parent_; + Frame frame_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/hash.h b/paddle/ap/include/axpr/hash.h new file mode 100644 index 00000000000000..f40b3580d5f236 --- /dev/null +++ b/paddle/ap/include/axpr/hash.h @@ -0,0 +1,48 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct Hash { + adt::Result operator()(InterpreterBase* interpreter, + const ValueT& val) const { + const auto& func = MethodClass::Hash(val); + using RetT = adt::Result; + return func.Match( + [&](const adt::Nothing&) -> RetT { + return adt::errors::TypeError{GetTypeName(val) + + " class has no __hash__ function."}; + }, + [&](adt::Result (*unary_func)(const ValueT&)) -> RetT { + ADT_LET_CONST_REF(hash_val, unary_func(val)); + ADT_LET_CONST_REF(hash, hash_val.template TryGet()); + return hash; + }, + [&](adt::Result (*unary_func)(InterpreterBase*, + const ValueT&)) -> RetT { + ADT_LET_CONST_REF(hash_val, unary_func(interpreter, val)); + ADT_LET_CONST_REF(hash, hash_val.template TryGet()); + return hash; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/instance_attrs.h b/paddle/ap/include/axpr/instance_attrs.h new file mode 100644 index 00000000000000..15ca2ab5ac7340 --- /dev/null +++ b/paddle/ap/include/axpr/instance_attrs.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/memory/circlable_ref.h" + +namespace ap::axpr { + +template +struct InstanceAttrs + : public memory::CirclableRef, AttrMapImpl> { + using Base = memory::CirclableRef, AttrMapImpl>; + using Base::CirclableRef; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/int.h b/paddle/ap/include/axpr/int.h new file mode 100644 index 00000000000000..934f4cc0366c5f --- /dev/null +++ b/paddle/ap/include/axpr/int.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using value_type = int64_t; + + const char* Name() const { return "int"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/int_data_type.h b/paddle/ap/include/axpr/int_data_type.h new file mode 100644 index 00000000000000..0609aa1b4857a9 --- /dev/null +++ b/paddle/ap/include/axpr/int_data_type.h @@ -0,0 +1,25 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#define AP_FOR_EACH_INT_TYPE(_) \ + _(int8) \ + _(uint8) \ + _(int16) \ + _(uint16) \ + _(int32) \ + _(uint32) \ + _(int64) \ + _(uint64) diff --git a/paddle/ap/include/axpr/int_method_class.h b/paddle/ap/include/axpr/int_method_class.h new file mode 100644 index 00000000000000..eafb83714c12da --- /dev/null +++ b/paddle/ap/include/axpr/int_method_class.h @@ -0,0 +1,133 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/bool_int_double_arithmetic_util.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/data_value.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct IntMethodClass { + using This = IntMethodClass; + using Self = int64_t; + + adt::Result ToString(Self int_val) { return std::to_string(int_val); } + + adt::Result Hash(Self int_val) { return int_val; } + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinUnarySymbol>::arithmetic_op_type; + return &This::UnaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::arithmetic_op_type; + return &This::template BinaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static adt::Result BinaryFunc(const ValueT& lhs_val, + const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template TryGet()); + return rhs_val.Match( + [&](bool rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](int64_t rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](double rhs) -> adt::Result { + return BoolIntDoubleArithmeticBinaryFunc(lhs, + rhs); + }, + [&](const auto& impl) -> adt::Result { + return adt::errors::TypeError{std::string() + + "unsupported operand type(s) for " + + ArithmeticOp::Name() + ": 'int' and '" + + axpr::GetTypeName(rhs_val) + "'"}; + }); + } + + template + static adt::Result UnaryFunc(const ValueT& val) { + ADT_LET_CONST_REF(operand, val.template 
TryGet()); + return BoolIntDoubleArithmeticUnaryFunc(operand); + } +}; + +template +struct MethodClassImpl : public IntMethodClass {}; + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + + adt::Result Call(const TypeImpl&) { + return &This::Construct; + } + + static adt::Result Construct(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "int() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + using T = int64_t; + using RetT = adt::Result; + return args.at(0).Match( + [&](bool c) -> RetT { return static_cast(c); }, + [&](int64_t c) -> RetT { return static_cast(c); }, + [&](double c) -> RetT { return static_cast(c); }, + [&](DataValue data_value) -> RetT { + return data_value.Match( + [&](const axpr::pstring&) -> RetT { + return adt::errors::TypeError{ + "invalid conversion from type 'pstring' to 'int'"}; + }, + [&](const adt::Undefined&) -> RetT { + return adt::errors::TypeError{ + "invalid conversion from type 'void' to 'int'"}; + }, + [&](const auto& impl) -> RetT { return static_cast(impl); }); + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "the argument 1 of int() should be bool/int/float/DataValue"}; + }); + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/interpreter.h b/paddle/ap/include/axpr/interpreter.h new file mode 100644 index 00000000000000..6ad91b1e6d8b21 --- /dev/null +++ b/paddle/ap/include/axpr/interpreter.h @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +class Interpreter { + public: + explicit Interpreter( + const axpr::AttrMap& builtin_frame_attr_map, + const std::weak_ptr& circlable_ref_list) + : builtin_frame_attr_map_(builtin_frame_attr_map), + circlable_ref_list_(circlable_ref_list) {} + + adt::Result Interpret(const Lambda& lambda, + const std::vector& args); + adt::Result Interpret(const axpr::Value& function, + const std::vector& args); + + adt::Result InterpretModule( + const Frame& const_global_frame, + const Lambda& lambda); + + private: + axpr::AttrMap builtin_frame_attr_map_; + std::weak_ptr circlable_ref_list_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/interpreter_base.h b/paddle/ap/include/axpr/interpreter_base.h new file mode 100644 index 00000000000000..2b112997ce2561 --- /dev/null +++ b/paddle/ap/include/axpr/interpreter_base.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/frame.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/memory/circlable_ref_list_base.h" + +namespace ap::axpr { + +template +class Environment; + +struct SerializableValue; + +template +class InterpreterBase { + public: + virtual Result InterpretCall(const ValueT& func, + const std::vector& args) = 0; + + virtual Result InterpretModule( + const Frame& const_global_frame, + const Lambda& lambda) = 0; + + virtual std::weak_ptr circlable_ref_list() + const = 0; + + virtual Result InterpretLambdaCall( + const std::shared_ptr>& env, + const ValueT& outer_func, + const Lambda& lambda, + const std::vector& args, + ComposedCallImpl* ret_composed_call) = 0; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/lambda_expr_builder.h b/paddle/ap/include/axpr/lambda_expr_builder.h new file mode 100644 index 00000000000000..8f51c81c6494dc --- /dev/null +++ b/paddle/ap/include/axpr/lambda_expr_builder.h @@ -0,0 +1,346 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "paddle/ap/include/axpr/anf_expr_builder.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/common/enforce.h" + +namespace ap::axpr { + +class LetContext; + +class LetVar { + public: + LetVar(const LetVar&) = default; + LetVar(LetVar&&) = default; + + LetVar& operator=(const LetVar& let_var); + LetVar& operator=(const AnfExpr& anf_val); + + const std::string& name() const { return name_; } + + operator Atomic() const { return tVar{name()}; } + + LetVar& Attr(const std::string& attr_name) { + return AttrImpl(Atomic{attr_name}); + } + LetVar& Attr(const LetVar& attr_name) { + return AttrImpl(static_cast>(attr_name)); + } + void SetAttr(const std::string& attr_name, const AnfExpr& anf_expr) { + return SetAttrImpl(Atomic{attr_name}, anf_expr); + } + void SetAttr(const LetVar& attr_name, const AnfExpr& anf_expr) { + return SetAttrImpl(static_cast>(attr_name), anf_expr); + } + void SetAttrImpl(const Atomic& attr_name, const AnfExpr& anf_expr); + LetVar& At(int64_t idx); + LetVar& At(const Atomic& idx); + + template + LetVar& Call(Args&&... args); + + template + LetVar& Apply(Args&&... 
args); + + LetContext* ctx() const { return let_ctx_; } + + private: + friend class LetContext; + LetVar(LetContext* let_ctx, const std::string& name) + : let_ctx_(let_ctx), name_(name) {} + + LetVar& AttrImpl(const Atomic& attr_name); + + LetContext* let_ctx_; + std::string name_; +}; + +class LetContext : public AtomicExprBuilder { + public: + explicit LetContext(const std::function& SeqNoGenerator) + : SeqNoGenerator_(SeqNoGenerator) {} + LetContext(const LetContext&) = delete; + LetContext(LetContext&&) = delete; + + using var_type = LetVar; + + LetVar& Var(const std::string& name) { + auto iter = let_var_storage_.find(name); + if (iter == let_var_storage_.end()) { + auto var = std::unique_ptr(new LetVar(this, name)); + iter = let_var_storage_.emplace(name, std::move(var)).first; + } + return *iter->second; + } + + template + AnfExpr Call(const LetVar& f, Arg0 arg0, Args&&... args) { + return ApplyImpl(f.name(), + std::vector{std::forward(arg0), + std::forward(args)...}); + } + + template + AnfExpr Call(const std::string& f, Arg0 arg0, Args&&... 
args) { + return ApplyImpl(f, + std::vector{std::forward(arg0), + std::forward(args)...}); + } + + AnfExpr Call(const LetVar& f) { + return ApplyImpl(f.name(), std::vector{}); + } + + AnfExpr Call(const std::string& f) { + return ApplyImpl(f, std::vector{}); + } + + AnfExpr Apply(const LetVar& f, const std::vector& vars) { + return Apply(f.name(), vars); + } + + AnfExpr Apply(const std::string& f, const std::vector& vars) { + std::vector args; + args.reserve(vars.size()); + for (const auto& var : vars) { + args.emplace_back(var); + } + return ApplyImpl(f, args); + } + + AnfExpr Apply(const LetVar& f, const std::vector& args) { + return ApplyImpl(f.name(), args); + } + + AnfExpr Apply(const std::string& f, const std::vector& args) { + return ApplyImpl(f, args); + } + + AnfExpr Apply(const LetVar& f, + const std::vector& args, + const std::map& kwargs) { + return Apply(f.name(), args, kwargs); + } + + AnfExpr Apply(const std::string& f, + const std::vector& args, + const std::map& kwargs) { + std::vector kwarg_list; + for (const auto& [keyword, val] : kwargs) { + const AnfExpr& item = + this->Call(ap::axpr::kBuiltinList(), this->String(keyword), val); + kwarg_list.emplace_back(item); + } + const AnfExpr& packed_args = + this->Call(this->Var("__builtin_PackedArgs__"), + this->Apply(ap::axpr::kBuiltinList(), args), + this->Apply(ap::axpr::kBuiltinList(), kwarg_list)); + return this->Call(f, packed_args); + } + + LetVar& Attr(const AnfExpr& self, const std::string& attr_name) { + const auto& var_name = NewTmpVarName(); + Var(var_name) = self; + return Var(var_name).Attr(attr_name); + } + + const std::vector>& bindings() { return bindings_; } + + std::string NewTmpVarName() { + static const std::string prefix = "___"; + return prefix + std::to_string(SeqNoGenerator_()); + } + + private: + friend class LetVar; + + AnfExpr ApplyImpl(const std::string& f, const std::vector& args) { + std::vector> atomic_args; + atomic_args.reserve(args.size()); + for (const auto& anf_expr : 
args) { + anf_expr.Match( + [&](const Atomic& atomic) { atomic_args.push_back(atomic); }, + [&](const auto&) { atomic_args.push_back(BindToTmpVar(anf_expr)); }); + } + return AnfExprBuilder().Call(tVar{f}, atomic_args); + } + + tVar BindToTmpVar(const AnfExpr& anf_val) { + const tVar tmp_var_name{NewTmpVarName()}; + AddBinding(tmp_var_name.value(), anf_val); + return tmp_var_name; + } + + void AddBinding(const std::string& name, const AnfExpr& anf_val) { + AnfExprBuilder anf; + anf_val.Match( + [&](const Atomic& atomic) { + const auto& combined = + anf.Call(tVar{kBuiltinIdentity()}, {atomic}); + bindings_.push_back(anf.Bind(name, combined)); + }, + [&](const Combined& combined) { + bindings_.push_back(anf.Bind(name, combined)); + }, + [&](const Let& let) { + const auto& lambda = anf.Lambda({}, let); + const auto& combined = anf.Call(lambda, {}); + bindings_.push_back(anf.Bind(name, combined)); + }); + } + + std::unordered_map> let_var_storage_; + std::vector> bindings_; + std::function SeqNoGenerator_; +}; + +inline LetVar& LetVar::operator=(const LetVar& let_var) { + AnfExprBuilder anf{}; + return *this = anf.Call(tVar{kBuiltinIdentity()}, + {tVar{let_var.name()}}); +} + +inline LetVar& LetVar::operator=(const AnfExpr& anf_val) { + let_ctx_->AddBinding(name_, anf_val); + return *this; +} + +inline LetVar& LetVar::AttrImpl(const Atomic& attr_name) { + AnfExprBuilder anf{}; + AnfExpr anf_expr = anf.Call(tVar{kBuiltinGetAttr()}, + {tVar{name()}, attr_name}); + return let_ctx_->Var(let_ctx_->BindToTmpVar(anf_expr).value()); +} + +inline void LetVar::SetAttrImpl(const Atomic& attr_name, + const AnfExpr& val) { + const auto& atomic = val.Match( + [&](const Atomic& atomic_val) -> Atomic { + return atomic_val; + }, + [&](const auto& impl) -> Atomic { + return let_ctx_->BindToTmpVar(val); + }); + AnfExprBuilder anf{}; + const auto& method_anf_expr = + anf.Call(tVar{kBuiltinSetAttr()}, + {tVar{name()}, attr_name}); + const auto& method = 
let_ctx_->BindToTmpVar(method_anf_expr); + AnfExpr anf_expr = anf.Call(method, {attr_name, atomic}); + let_ctx_->BindToTmpVar(anf_expr); +} + +inline LetVar& LetVar::At(int64_t idx) { + AnfExprBuilder anf{}; + AnfExpr anf_expr = anf.Call(tVar{kBuiltinGetItem()}, + {tVar{name()}, anf.Int64(idx)}); + return let_ctx_->Var(let_ctx_->BindToTmpVar(anf_expr).value()); +} + +inline LetVar& LetVar::At(const Atomic& idx) { + AnfExprBuilder anf{}; + AnfExpr anf_expr = anf.Call(tVar{kBuiltinGetItem()}, + {tVar{name()}, idx}); + return let_ctx_->Var(let_ctx_->BindToTmpVar(anf_expr).value()); +} + +template +inline LetVar& LetVar::Call(Args&&... args) { + const auto& anf_expr = let_ctx_->Call(*this, std::forward(args)...); + return let_ctx_->Var(let_ctx_->BindToTmpVar(anf_expr).value()); +} + +template +inline LetVar& LetVar::Apply(Args&&... args) { + const auto& anf_expr = let_ctx_->Apply(*this, std::forward(args)...); + return let_ctx_->Var(let_ctx_->BindToTmpVar(anf_expr).value()); +} + +class LambdaExprBuilder { + public: + LambdaExprBuilder() : SeqNoGenerator_(&LambdaExprBuilder::GenSeqNo) {} + explicit LambdaExprBuilder(const std::function& SeqNoGenerator) + : SeqNoGenerator_(SeqNoGenerator) {} + LambdaExprBuilder(const LambdaExprBuilder&) = delete; + LambdaExprBuilder(LambdaExprBuilder&&) = delete; + + AnfExpr Lambda(const std::vector& args, + const std::function& GetBody) { + AnfExpr anf_expr = Let(GetBody); + AnfExpr lambda_or_body = anf_expr.Match( + [&](const ap::axpr::Let& let) { + if (let->bindings.empty()) { + return let->body; + } else { + return anf_expr; + } + }, + [&](const auto&) { return anf_expr; }); + return anf_.Lambda(MakeLambdaArgs(args), lambda_or_body); + } + + AnfExpr Let(const std::function& GetBody) { + LetContext let_ctx{SeqNoGenerator_}; + AnfExpr ret = GetBody(let_ctx); + return anf_.Let(let_ctx.bindings(), ret); + } + + adt::Result TryLambda( + const std::vector& args, + const std::function(LetContext&)>& GetBody) { + ADT_LET_CONST_REF(anf_expr, 
TryLet(GetBody)); + AnfExpr lambda_or_body = anf_expr.Match( + [&](const ap::axpr::Let& let) { + if (let->bindings.empty()) { + return let->body; + } else { + return anf_expr; + } + }, + [&](const auto&) { return anf_expr; }); + return anf_.Lambda(MakeLambdaArgs(args), lambda_or_body); + } + + adt::Result TryLet( + const std::function(LetContext&)>& GetBody) { + LetContext let_ctx{SeqNoGenerator_}; + ADT_LET_CONST_REF(ret, GetBody(let_ctx)); + return anf_.Let(let_ctx.bindings(), ret); + } + + std::vector> MakeLambdaArgs( + const std::vector& args) { + std::vector> lambda_args; + lambda_args.reserve(args.size()); + for (const auto& arg : args) { + lambda_args.emplace_back(arg); + } + return lambda_args; + } + + private: + static size_t GenSeqNo() { + static std::atomic seq_no(0); + return seq_no++; + } + + std::function SeqNoGenerator_; + AnfExprBuilder anf_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/list.h b/paddle/ap/include/axpr/list.h new file mode 100644 index 00000000000000..122be7b3135886 --- /dev/null +++ b/paddle/ap/include/axpr/list.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct TypeImpl> : public std::monostate { + using value_type = adt::List; + + const char* Name() const { return "list"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/list_method_class.h b/paddle/ap/include/axpr/list_method_class.h new file mode 100644 index 00000000000000..6d2a72ed2af984 --- /dev/null +++ b/paddle/ap/include/axpr/list_method_class.h @@ -0,0 +1,147 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/starred.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + using Self = adt::List; + + adt::Result Length(const Self& self) { + return static_cast(self->size()); + } + + adt::Result ToString(axpr::InterpreterBase* interpreter, + const Self& self) { + std::ostringstream ss; + ss << "["; + int i = 0; + using Ok = adt::Result; + for (const auto& elt : *self) { + if (i++ > 0) { + ss << ", "; + } + const auto& func = MethodClass::ToString(elt); + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return adt::errors::TypeError{GetTypeName(elt) + + " class has no __str__ method"}; + }, + [&](adt::Result (*unary_func)(const ValueT&)) -> Ok { + ADT_LET_CONST_REF(str_val, unary_func(elt)); + ADT_LET_CONST_REF(str, str_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(elt) + + ".__str__ should return a 'str' but '" + + axpr::GetTypeName(str_val) + "' were returned."}; + ss << str; + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)(axpr::InterpreterBase*, + const ValueT&)) -> Ok { + ADT_LET_CONST_REF(str_val, unary_func(interpreter, elt)); + ADT_LET_CONST_REF(str, str_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(elt) + + ".__str__ should return a 'str' but '" + + axpr::GetTypeName(str_val) + "' were returned."}; + ss << str; + return adt::Ok{}; + })); + } + ss << "]"; + return ss.str(); + } + + static Result EQ(const ValueT& lhs_val, const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template TryGet()); + ADT_LET_CONST_REF(rhs, rhs_val.template TryGet()); + return lhs == rhs; + } + + adt::Result Hash(axpr::InterpreterBase* interpreter, + const Self& self) { + int64_t hash_value = 0; + using Ok = adt::Result; + for (const 
auto& elt : *self) { + const auto& func = MethodClass::Hash(elt); + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return adt::errors::TypeError{std::string() + GetTypeName(elt) + + " class has no __hash__ method"}; + }, + [&](adt::Result (*unary_func)(const ValueT&)) -> Ok { + ADT_LET_CONST_REF(elt_hash_val, unary_func(elt)); + ADT_LET_CONST_REF(elt_hash, elt_hash_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(elt) + + ".__hash__ should return a 'int' but '" + + axpr::GetTypeName(elt_hash_val) + "' were returned."}; + hash_value = adt::hash_combine(hash_value, elt_hash); + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)(axpr::InterpreterBase*, + const ValueT&)) -> Ok { + ADT_LET_CONST_REF(elt_hash_val, unary_func(interpreter, elt)); + ADT_LET_CONST_REF(elt_hash, elt_hash_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(elt) + + ".__hash__ should return a 'int' but '" + + axpr::GetTypeName(elt_hash_val) + "' were returned."}; + hash_value = adt::hash_combine(hash_value, elt_hash); + return adt::Ok{}; + })); + } + return hash_value; + } + + adt::Result GetItem(const Self& self, const ValueT& idx) { + return idx.Match( + [&](int64_t index) -> Result { + if (index < 0) { + index += self->size(); + } + if (index >= 0 && index < self->size()) { + return self->at(index); + } + return adt::errors::IndexError{"list index out of range"}; + }, + [&](const auto&) -> Result { + return adt::errors::TypeError{std::string() + + "list indices must be integers, not " + + axpr::GetTypeName(idx)}; + }); + } + + adt::Result Starred(const Self& self) { + return ap::axpr::Starred{self}; + } +}; + +template +struct MethodClassImpl>> { + using Self = TypeImpl>; + + using This = MethodClassImpl; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/method.h b/paddle/ap/include/axpr/method.h new file mode 100644 index 00000000000000..8aca3195119ef9 --- 
/dev/null +++ b/paddle/ap/include/axpr/method.h @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct MethodImpl { + ValueT obj; + ValueT func; + + bool operator==(const MethodImpl& other) const { + return other.obj == this->obj && other.func == this->func; + } +}; + +template +ADT_DEFINE_RC(Method, const MethodImpl); + +template +struct TypeImpl> : public std::monostate { + using value_type = Method; + + const char* Name() const { return "method"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/method_class.h b/paddle/ap/include/axpr/method_class.h new file mode 100644 index 00000000000000..a9fe458a087d47 --- /dev/null +++ b/paddle/ap/include/axpr/method_class.h @@ -0,0 +1,483 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/class_instance.h" +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +class InterpreterBase; + +template +using BuiltinUnaryFuncImpl = + std::variant (*)(const ValueT&), + adt::Result (*)(InterpreterBase*, + const ValueT&)>; + +template +struct BuiltinUnaryFunc : public BuiltinUnaryFuncImpl { + using BuiltinUnaryFuncImpl::BuiltinUnaryFuncImpl; + ADT_DEFINE_VARIANT_METHODS(BuiltinUnaryFuncImpl); +}; + +template BuiltinFunc> +adt::Result UnaryFuncReturnCapturedValue(const ValueT&) { + return BuiltinFunc; +} + +template +using BuiltinBinaryFuncImpl = + std::variant (*)(const ValueT&, const ValueT&), + adt::Result (*)( + InterpreterBase*, const ValueT&, const ValueT&)>; + +template +struct BuiltinBinaryFunc : public BuiltinBinaryFuncImpl { + using BuiltinBinaryFuncImpl::BuiltinBinaryFuncImpl; + ADT_DEFINE_VARIANT_METHODS(BuiltinBinaryFuncImpl); +}; + +template +struct EmptyMethodClass { + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + return adt::Nothing{}; + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + return adt::Nothing{}; + } +}; + +template +struct MethodClassImpl; + +namespace detail { + +template +struct BuiltinMethodHelperImpl; + +#define 
SPECIALIZE_BuiltinMethodHelperImpl(symbol_name, op) \ + template \ + struct BuiltinMethodHelperImpl { \ + using This = \ + BuiltinMethodHelperImpl; \ + \ + template \ + using UnaryMethodRetT = \ + decltype(std::declval&>().symbol_name( \ + std::declval())); \ + \ + template \ + using HighOrderUnaryMethodRetT = \ + decltype(std::declval&>().symbol_name( \ + std::declval*>(), \ + std::declval())); \ + \ + static constexpr bool HasUnaryMethod() { \ + return builtin_symbol::symbol_name::num_operands == 1 && \ + std::experimental::is_detected_v; \ + } \ + \ + static constexpr bool HasHighOrderUnaryMethod() { \ + return builtin_symbol::symbol_name::num_operands == 1 && \ + std::experimental::is_detected_v; \ + } \ + \ + static adt::Result UnaryCall(const T& obj) { \ + if constexpr (This::HasUnaryMethod()) { \ + return MethodClassImpl{}.symbol_name(obj); \ + } else { \ + return adt::errors::RuntimeError{"`" #symbol_name \ + "` method not found."}; \ + } \ + } \ + \ + static adt::Result HighOrderUnaryCall( \ + InterpreterBase* interpreter, const T& obj) { \ + if constexpr (This::HasHighOrderUnaryMethod()) { \ + return MethodClassImpl{}.symbol_name(interpreter, obj); \ + } else { \ + return adt::errors::RuntimeError{"`" #symbol_name \ + "` method not found."}; \ + } \ + } \ + \ + template \ + using BinaryMethodRetT = \ + decltype(std::declval&>().symbol_name( \ + std::declval(), std::declval())); \ + \ + template \ + using HighOrderBinaryMethodRetT = \ + decltype(std::declval&>().symbol_name( \ + std::declval*>(), \ + std::declval(), \ + std::declval())); \ + \ + static constexpr bool HasBinaryMethod() { \ + return builtin_symbol::symbol_name::num_operands == 2 && \ + std::experimental::is_detected_v; \ + } \ + \ + static constexpr bool HasHighOrderBinaryMethod() { \ + return builtin_symbol::symbol_name::num_operands == 2 && \ + std::experimental::is_detected_v; \ + } \ + \ + static adt::Result BinaryCall(const T& obj, const ValueT& arg) { \ + if constexpr 
(This::HasBinaryMethod()) { \ + return MethodClassImpl{}.symbol_name(obj, arg); \ + } else { \ + return adt::errors::RuntimeError{"`" #symbol_name \ + "` method not found."}; \ + } \ + } \ + static adt::Result HighOrderBinaryCall( \ + InterpreterBase* interpreter, \ + const T& obj, \ + const ValueT& arg) { \ + if constexpr (This::HasHighOrderBinaryMethod()) { \ + return MethodClassImpl{}.symbol_name( \ + interpreter, obj, arg); \ + } else { \ + return adt::errors::RuntimeError{"`" #symbol_name \ + "` method not found."}; \ + } \ + } \ + }; + +AXPR_FOR_EACH_SYMBOL_OP(SPECIALIZE_BuiltinMethodHelperImpl) + +#undef SPECIALIZE_BuiltinMethodHelperImpl + +template +struct DirectAlternative { + static adt::Result TryGet(const VariantT& val) { + if (val.template Has()) { + return val.template Get(); + } + return adt::errors::TypeError{"cast failed."}; + } +}; + +template +struct IndirectAlternative { + static adt::Result TryGet(const ValueT& val) { + using TypeT = typename TypeTrait::TypeT; + ADT_LET_CONST_REF(type, DirectAlternative::TryGet(val)); + return DirectAlternative::TryGet(type); + } +}; + +template + class Alternative> +struct BuiltinMethodHelper { + using This = BuiltinMethodHelper; + using Impl = BuiltinMethodHelperImpl; + + static constexpr bool HasUnaryMethod() { return Impl::HasUnaryMethod(); } + + static constexpr bool HasHighOrderUnaryMethod() { + return Impl::HasHighOrderUnaryMethod(); + } + + static constexpr BuiltinUnaryFunc GetBuiltinUnaryMethod() { + return &This::MakeBuiltinUnaryFunc<&Impl::UnaryCall>; + } + + static constexpr BuiltinUnaryFunc GetBuiltinHighOrderUnaryMethod() { + return &This::MakeBuiltinHighOrderUnaryFunc<&Impl::HighOrderUnaryCall>; + } + + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + static const MethodClassImpl + detect_specialization_of_method_class_impl; + (void)detect_specialization_of_method_class_impl; + if constexpr (HasUnaryMethod()) { + return GetBuiltinUnaryMethod(); + } else if constexpr (HasHighOrderUnaryMethod()) 
{ + return GetBuiltinHighOrderUnaryMethod(); + } else if constexpr (HasDefaultUnaryMethod()) { + return MethodClassImpl::template GetBuiltinUnaryFunc(); + } else { + return adt::Nothing{}; + } + } + + template + using UnaryMethodRetT = + decltype(MethodClassImpl::template GetBuiltinUnaryFunc< + BuiltinSymbol>()); + + static constexpr bool HasDefaultUnaryMethod() { + return std::experimental::is_detected_v; + } + + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + static const MethodClassImpl + detect_specialization_of_method_class_impl; + (void)detect_specialization_of_method_class_impl; + if constexpr (Impl::HasBinaryMethod()) { + return &This::MakeBuiltinBinaryFunc<&Impl::BinaryCall>; + } else if constexpr (Impl::HasHighOrderBinaryMethod()) { + return &This::MakeBuiltinHighOrderBinaryFunc<&Impl::HighOrderBinaryCall>; + } else if constexpr (HasDefaultBinaryMethod()) { + return MethodClassImpl::template GetBuiltinBinaryFunc(); + } else { + return adt::Nothing{}; + } + } + + template + using BinaryMethodRetT = + decltype(MethodClassImpl::template GetBuiltinBinaryFunc< + BuiltinSymbol>()); + + static constexpr bool HasDefaultBinaryMethod() { + return std::experimental::is_detected_v; + } + + template (*UnaryFunc)(const T&)> + static adt::Result MakeBuiltinUnaryFunc(const ValueT& obj_val) { + ADT_LET_CONST_REF(obj, Alternative::TryGet(obj_val)); + const auto& ret = UnaryFunc(obj); + return ret; + } + + template (*UnaryFunc)(InterpreterBase*, + const T&)> + static adt::Result MakeBuiltinHighOrderUnaryFunc( + InterpreterBase* interpreter, const ValueT& obj_val) { + ADT_LET_CONST_REF(obj, Alternative::TryGet(obj_val)); + const auto& ret = UnaryFunc(interpreter, obj); + return ret; + } + + template (*BinaryFunc)(const T&, const ValueT&)> + static adt::Result MakeBuiltinBinaryFunc(const ValueT& obj_val, + const ValueT& arg) { + ADT_LET_CONST_REF(obj, Alternative::TryGet(obj_val)); + return BinaryFunc(obj, arg); + } + template (*BinaryFunc)( + InterpreterBase*, const T&, 
const ValueT&)> + static adt::Result MakeBuiltinHighOrderBinaryFunc( + InterpreterBase* interpreter, + const ValueT& obj_val, + const ValueT& arg) { + ADT_LET_CONST_REF(obj, Alternative::TryGet(obj_val)); + return BinaryFunc(interpreter, obj, arg); + } +}; + +} // namespace detail + +template +struct MethodClass { + using This = MethodClass; + + static BuiltinUnaryFunc Hash(const ValueT& val) { + using S = builtin_symbol::Hash; + return val.Match([](const auto& impl) -> BuiltinUnaryFunc { + using T = std::decay_t; + if constexpr (IsType()) { + return impl.Match([](const auto& type_impl) + -> BuiltinUnaryFunc { + using TT = std::decay_t; + using Helper = detail:: + BuiltinMethodHelper; + if constexpr (Helper::HasUnaryMethod()) { + return Helper::GetBuiltinUnaryMethod(); + } else if constexpr (Helper::HasHighOrderUnaryMethod()) { + return Helper::GetBuiltinHighOrderUnaryMethod(); + } else { + return &This::TypeDefaultHash; + } + }); + } else { + using Helper = detail:: + BuiltinMethodHelper; + if constexpr (Helper::HasUnaryMethod()) { + return Helper::GetBuiltinUnaryMethod(); + } else if constexpr (Helper::HasHighOrderUnaryMethod()) { + return Helper::GetBuiltinHighOrderUnaryMethod(); + } else { + return &This::InstanceDefaultHash; + } + } + }); + } + + template + static adt::Result TypeDefaultHash(const ValueT& val) { + int64_t hash_value = std::hash()(typeid(TT).name()); + return hash_value; + } + + template + static adt::Result InstanceDefaultHash(const ValueT& val) { + ADT_LET_CONST_REF(impl, val.template TryGet()); + // please implement MethodClassImpl::Hash if T is not defined + // by ADT_DEFINE_RC. 
+ const void* ptr = impl.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static BuiltinUnaryFunc ToString(const ValueT& val) { + using S = builtin_symbol::ToString; + return val.Match([](const auto& impl) -> BuiltinUnaryFunc { + using T = std::decay_t; + if constexpr (IsType()) { + return impl.Match([](const auto& type_impl) + -> BuiltinUnaryFunc { + using TT = std::decay_t; + using Helper = detail:: + BuiltinMethodHelper; + if constexpr (Helper::HasUnaryMethod()) { + return Helper::GetBuiltinUnaryMethod(); + } else if constexpr (Helper::HasHighOrderUnaryMethod()) { + return Helper::GetBuiltinHighOrderUnaryMethod(); + } else { + return &This::TypeDefaultToString; + } + }); + } else { + using Helper = detail:: + BuiltinMethodHelper; + if constexpr (Helper::HasUnaryMethod()) { + return Helper::GetBuiltinUnaryMethod(); + } else if constexpr (Helper::HasHighOrderUnaryMethod()) { + return Helper::GetBuiltinHighOrderUnaryMethod(); + } else { + return &This::InstanceDefaultToString; + } + } + }); + } + + template + static adt::Result TypeDefaultToString(const ValueT& val) { + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + template + static adt::Result InstanceDefaultToString(const ValueT& val) { + std::ostringstream ss; + ADT_LET_CONST_REF(impl, val.template TryGet()); + // please implement MethodClassImpl::ToString if T is not defined + // by ADT_DEFINE_RC. 
+ const void* ptr = impl.__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << TypeImpl{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc(const ValueT& val) { + using S = BuiltinUnarySymbol; + return val.Match([](const auto& impl) -> BuiltinUnaryFunc { + using T = std::decay_t; + if constexpr (IsType()) { + return impl.Match([](const auto& type_impl) + -> BuiltinUnaryFunc { + using TT = std::decay_t; + using Helper = detail:: + BuiltinMethodHelper; + return Helper::GetBuiltinUnaryFunc(); + }); + } else { + using Helper = detail:: + BuiltinMethodHelper; + return Helper::GetBuiltinUnaryFunc(); + } + }); + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc(const ValueT& val) { + using S = BuiltinBinarySymbol; + return val.Match([](const auto& impl) -> BuiltinBinaryFunc { + using T = std::decay_t; + if constexpr (IsType()) { + return impl.Match([](const auto& type_impl) + -> BuiltinBinaryFunc { + using TT = std::decay_t; + using Helper = detail:: + BuiltinMethodHelper; + return Helper::GetBuiltinBinaryFunc(); + }); + } else { + using Helper = detail:: + BuiltinMethodHelper; + return Helper::GetBuiltinBinaryFunc(); + } + }); + } +}; + +template +using __AltT = decltype(std::declval().template Get()); + +template +adt::Result TryGetAlternative(const ValueT& val) { + if constexpr (std::experimental::is_detected_v<__AltT, ValueT, T>) { + return val.template TryGet(); + } else { + return detail::IndirectAlternative::TryGet(val); + } +} + +template +adt::Result TryGetImpl(const ValueT& val) { + return TryGetAlternative(val); +} + +template +std::string GetTypeName(const ValueT& val) { + return val.Match( + [](const BuiltinClassInstance& impl) -> std::string { + return impl.type.class_attrs()->class_name; + }, + [](const ClassInstance& impl) -> std::string { + return impl->type.class_attrs->class_name; + }, + [](const auto& impl) -> std::string { + using T = std::decay_t; + return TypeImpl{}.Name(); 
+ }); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/method_method_class.h b/paddle/ap/include/axpr/method_method_class.h new file mode 100644 index 00000000000000..5e54bf0bc0cad1 --- /dev/null +++ b/paddle/ap/include/axpr/method_method_class.h @@ -0,0 +1,57 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct MethodMethodClass { + using Self = MethodMethodClass; + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + return adt::Nothing{}; + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + return adt::Nothing{}; + } +}; + +template +struct MethodClassImpl> { + using method_class = MethodMethodClass; + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + return method_class::template GetBuiltinUnaryFunc(); + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + return method_class::template GetBuiltinBinaryFunc(); + } +}; + +template +struct MethodClassImpl>> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/module_mgr.h b/paddle/ap/include/axpr/module_mgr.h new file mode 100644 index 00000000000000..4343103890dd42 --- /dev/null +++ b/paddle/ap/include/axpr/module_mgr.h @@ -0,0 
+1,199 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "glog/logging.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/builtin_func_name_mgr.h" +#include "paddle/ap/include/axpr/frame.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/env/ap_path.h" +#include "paddle/ap/include/memory/guard.h" +#include "paddle/ap/include/preprocessor/preprocessor.h" + +namespace ap::axpr { + +class ModuleMgr { + public: + ModuleMgr() + : memory_guard_(), + file_path2const_global_frame_(), + module_name2const_global_frame_() {} + + ModuleMgr(const ModuleMgr&) = delete; + ModuleMgr(ModuleMgr&&) = delete; + + static ModuleMgr* Singleton() { + static ModuleMgr module_mgr; + return &module_mgr; + } + + std::optional> OptGetBuiltinModule( + const std::string& module_name) { + const auto iter = module_name2builtin_module_.find(module_name); + if (iter == module_name2builtin_module_.end()) return std::nullopt; + return iter->second; + } + + template + adt::Result> GetOrCreateByModuleName( + const std::string& module_name, const InitT& Init) { + { + const auto& iter = module_name2const_global_frame_.find(module_name); + if (iter != module_name2const_global_frame_.end()) { + return iter->second; + } + } + ADT_LET_CONST_REF(file_path, GetFilePathByModuleName(module_name)) + 
<< adt::errors::ModuleNotFoundError{ + std::string() + "No module named '" + module_name + "'"}; + ADT_LET_CONST_REF(frame, GetOrCreateByFilePath(file_path, Init)); + ADT_CHECK( + module_name2const_global_frame_.emplace(module_name, frame).second); + return frame; + } + + template + adt::Result> GetOrCreateByFilePath( + const std::string& file_path, const InitT& Init) { + const auto& iter = file_path2const_global_frame_.find(file_path); + if (iter != file_path2const_global_frame_.end()) { + return iter->second; + } + auto frame_object = std::make_shared>(); + const auto& frame = + Frame::Make(circlable_ref_list(), frame_object); + ADT_LET_CONST_REF(lambda, GetLambdaByFilePath(file_path)); + ADT_CHECK(file_path2const_global_frame_.emplace(file_path, frame).second); + ADT_RETURN_IF_ERR(Init(frame, lambda)); + return frame; + } + + const std::shared_ptr& circlable_ref_list() + const { + return memory_guard_.circlable_ref_list(); + } + + void RegisterBuiltinFrame(const std::string& name, + const axpr::AttrMap& attr_map) { + CHECK(module_name2builtin_module_.emplace(name, attr_map).second); + } + + private: + adt::Result GetFilePathByModuleName( + const std::string& module_name) { + std::optional file_path; + using RetT = adt::Result; + ADT_RETURN_IF_ERR( + VisitEachConfigFilePath([&](const std::string& dir_name) -> RetT { + const std::string& cur_file_path = + dir_name + "/" + module_name + ".py.json"; + if (FileExists(cur_file_path)) { + file_path = cur_file_path; + return adt::Break{}; + } else { + return adt::Continue{}; + } + })); + ADT_CHECK(file_path.has_value()); + return file_path.value(); + } + + adt::Result> GetLambdaByFilePath( + const std::string& file_path) { + ADT_LET_CONST_REF(file_content, GetFileContent(file_path)); + ADT_CHECK(!file_content.empty()); + ADT_LET_CONST_REF(anf_expr, axpr::MakeAnfExprFromJsonString(file_content)); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + std::vector> args{}; + axpr::Lambda lambda{args, 
core_expr}; + return lambda; + } + + adt::Result GetFileContent(const std::string& filepath) { + std::ifstream ifs(filepath); + std::string content{std::istreambuf_iterator(ifs), + std::istreambuf_iterator()}; + return content; + } + + bool FileExists(const std::string& filepath) { + std::fstream fp; + fp.open(filepath, std::fstream::in); + if (fp.is_open()) { + fp.close(); + return true; + } else { + return false; + } + } + + template + adt::Result VisitEachConfigFilePath(const DoEachT& DoEach) { + return ap::env::VisitEachApPath(DoEach); + } + + memory::Guard memory_guard_; + + std::unordered_map> + file_path2const_global_frame_; + + std::unordered_map> + module_name2const_global_frame_; + + std::unordered_map> + module_name2builtin_module_; +}; + +struct ApBuiltinModuleBuilder { + std::string module_name; + axpr::AttrMap attr_map; + + void Def(const std::string& name, + const axpr::BuiltinFuncType& func) { + void* func_ptr = reinterpret_cast(func); + attr_map->Set(name, BuiltinFuncVoidPtr{func_ptr}); + BuiltinFuncNameMgr::Singleton()->Register(module_name, name, func_ptr); + } + + void Def(const std::string& name, + const axpr::BuiltinHighOrderFuncType& func) { + void* func_ptr = reinterpret_cast(func); + attr_map->Set(name, BuiltinHighOrderFuncVoidPtr{func_ptr}); + BuiltinFuncNameMgr::Singleton()->Register(module_name, name, func_ptr); + } +}; + +struct ApBuiltinModuleRegistryHelper { + ApBuiltinModuleRegistryHelper( + const std::string& name, + const std::function& func) { + ApBuiltinModuleBuilder builder{name}; + func(&builder); + ModuleMgr::Singleton()->RegisterBuiltinFrame(name, builder.attr_map); + } +}; + +#define REGISTER_AP_BUILTIN_MODULE(name, ...) 
\ + namespace { \ + ::ap::axpr::ApBuiltinModuleRegistryHelper AP_CONCAT( \ + ap_builtin_module_registry_helper, __LINE__)(name, __VA_ARGS__); \ + } + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/module_mgr_helper.h b/paddle/ap/include/axpr/module_mgr_helper.h new file mode 100644 index 00000000000000..6d45d692933798 --- /dev/null +++ b/paddle/ap/include/axpr/module_mgr_helper.h @@ -0,0 +1,56 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/interpreter_base.h" +#include "paddle/ap/include/axpr/module_mgr.h" + +namespace ap::axpr { + +template +struct ModuleMgrHelper { + using This = ModuleMgrHelper; + + static adt::Result ImportModule(InterpreterBase* interpreter, + const ValueT&, + const std::vector& args) { + return This{}.Import(interpreter, args); + } + + adt::Result Import(InterpreterBase* interpreter, + const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(module_name, args.at(0).template TryGet()); + auto* module_mgr = ModuleMgr::Singleton(); + const auto& opt_builtin_module = + module_mgr->OptGetBuiltinModule(module_name); + if (opt_builtin_module.has_value()) { + return opt_builtin_module.value(); + } + auto Init = [&](const Frame& frame, + const axpr::Lambda& lambda) + -> adt::Result { + ADT_RETURN_IF_ERR(interpreter->InterpretModule(frame, lambda)); + return adt::Ok{}; + }; + ADT_LET_CONST_REF(frame, + module_mgr->GetOrCreateByModuleName(module_name, Init)); + ADT_LET_CONST_REF(frame_impl_obj, frame.shared_ptr()); + return axpr::AttrMap{frame_impl_obj}; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/mutable_global_environment.h b/paddle/ap/include/axpr/mutable_global_environment.h new file mode 100644 index 00000000000000..c4df8a0a4dc031 --- /dev/null +++ b/paddle/ap/include/axpr/mutable_global_environment.h @@ -0,0 +1,100 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/environment.h" +#include "paddle/ap/include/axpr/frame.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/serializable_value_helper.h" + +namespace ap::axpr { + +template +class MutableGlobalEnvironment : public Environment { + public: + MutableGlobalEnvironment(const std::shared_ptr>& parent, + const Frame& const_frame, + const Frame& temp_frame) + : parent_(parent), const_frame_(const_frame), temp_frame_(temp_frame) {} + + adt::Result Get(const std::string& var) const override { + if (IsTempVar(var)) { + ADT_LET_CONST_REF(temp_frame_ptr, temp_frame_.Get()); + const auto& val_in_temp_frame = temp_frame_ptr->OptGet(var); + if (val_in_temp_frame.has_value()) { + return val_in_temp_frame.value(); + } + } else { + ADT_LET_CONST_REF(const_frame_ptr, const_frame_.Get()); + const auto& val_in_const_frame = const_frame_ptr->OptGet(var); + if (val_in_const_frame.has_value()) { + return val_in_const_frame.value().template CastTo(); + } + } + if (parent_ == nullptr) { + return NameError{std::string("name '") + var + "' is not defined."}; + } + return parent_->Get(var); + } + + adt::Result Set(const std::string& var, const ValueT& val) override { + if (IsTempVar(var)) { + ADT_LET_CONST_REF(temp_frame_ptr, temp_frame_.Mut()); + temp_frame_ptr->Set(var, val); + } else { + ADT_LET_CONST_REF(const_frame_ptr, const_frame_.Mut()); + SerializableValueHelper helper{}; + ADT_LET_CONST_REF(serializable_val, helper.CastFrom(val)) << [&] { + std::ostringstream ss; + ss << "Only serializable values are supported insert into global " + "environment. 
"; + ss << "Builtin serializable types are: "; + ss << SerializableValue::SerializableTypeNames(); + ss << " (not include '" << axpr::GetTypeName(val) << "')."; + return adt::errors::ValueError{ss.str()}; + }(); + const_frame_ptr->Set(var, serializable_val); + } + return adt::Ok{}; + } + + bool IsTempVar(const std::string& var) const { + static std::string tmp_var_prefix("__"); + return var.substr(0, tmp_var_prefix.size()) == tmp_var_prefix; + } + + std::optional> GetConstGlobalFrame() const override { + return const_frame_; + } + + std::optional> RecursivelyGetConstGlobalFrame() + const override { + return const_frame_; + } + + const Frame& temp_frame() const { return temp_frame_; } + + private: + MutableGlobalEnvironment(const MutableGlobalEnvironment&) = delete; + MutableGlobalEnvironment(MutableGlobalEnvironment&&) = delete; + + std::shared_ptr> parent_; + Frame const_frame_; + Frame temp_frame_; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/mutable_list.h b/paddle/ap/include/axpr/mutable_list.h new file mode 100644 index 00000000000000..de5b4fcfe7e131 --- /dev/null +++ b/paddle/ap/include/axpr/mutable_list.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/memory/circlable_ref.h" + +namespace ap::axpr { + +template +struct MutableList + : public memory::CirclableRef, std::vector> { + using Base = memory::CirclableRef, std::vector>; + using Base::CirclableRef; +}; + +template +struct TypeImpl> : public std::monostate { + using value_type = MutableList; + + const char* Name() const { return "MutableList"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/mutable_list_method_class.h b/paddle/ap/include/axpr/mutable_list_method_class.h new file mode 100644 index 00000000000000..2c0cac32b16d8c --- /dev/null +++ b/paddle/ap/include/axpr/mutable_list_method_class.h @@ -0,0 +1,180 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/mutable_list.h" +#include "paddle/ap/include/axpr/starred.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + using Self = MutableList; + + adt::Result Length(const Self& self) { + ADT_LET_CONST_REF(vec, self.Get()); + return static_cast(vec->size()); + } + + adt::Result ToString(axpr::InterpreterBase* interpreter, + const Self& self) { + ADT_LET_CONST_REF(vec, self.Get()); + std::ostringstream ss; + ss << "["; + int i = 0; + using Ok = adt::Result; + for (const auto& elt : *vec) { + if (i++ > 0) { + ss << ", "; + } + const auto& func = MethodClass::ToString(elt); + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return adt::errors::TypeError{GetTypeName(elt) + + " class has no __str__ function"}; + }, + [&](adt::Result (*unary_func)(const ValueT&)) -> Ok { + ADT_LET_CONST_REF(str_val, unary_func(elt)); + ADT_LET_CONST_REF(str, str_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(elt) + + ".__builtin_ToString__ should return a 'str' but '" + + axpr::GetTypeName(str_val) + "' were returned."}; + ss << str; + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)(axpr::InterpreterBase*, + const ValueT&)) -> Ok { + ADT_LET_CONST_REF(str_val, unary_func(interpreter, elt)); + ADT_LET_CONST_REF(str, str_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(elt) + + ".__builtin_ToString__ should return a 'str' but '" + + axpr::GetTypeName(str_val) + "' were returned."}; + ss << str; + return adt::Ok{}; + })); + } + ss << "]"; + return ss.str(); + } + + adt::Result Hash(const Self& self) { + return adt::errors::TypeError{"MutableList objects are not hashable"}; + } + + adt::Result GetItem(const Self& self, const ValueT& idx) { + 
ADT_LET_CONST_REF(vec, self.Get()); + return idx.Match( + [&](int64_t index) -> Result { + if (index < 0) { + index += vec->size(); + } + if (index >= 0 && index < vec->size()) { + return vec->at(index); + } + return adt::errors::IndexError{"list index out of range"}; + }, + [&](const auto&) -> Result { + return adt::errors::TypeError{std::string() + + "list indices must be integers, not " + + axpr::GetTypeName(idx)}; + }); + } + + adt::Result SetItem(const Self& self, const ValueT& idx) { + return Method{self, &This::StaticSetItem}; + } + + static adt::Result StaticSetItem(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template TryGet()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(const_idx, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + "list indices must be integers or slices, not " + + axpr::GetTypeName(args.at(0))}; + ADT_LET_CONST_REF(self_ptr, self.Mut()); + int64_t idx = const_idx; + if (idx < 0) { + idx += self_ptr->size(); + } + ADT_CHECK(idx < self_ptr->size()) + << adt::errors::IndexError{"list index out of range"}; + self_ptr->at(idx) = args.at(1); + return adt::Nothing{}; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "append") { + return Method{self, &This::StaticAppend}; + } + return adt::errors::AttributeError{ + std::string() + "'MutableList' object has no attribute '" + attr_name + + "'"}; + } + + static adt::Result StaticAppend(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template TryGet()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "MutableList.append() takes exactly one argument (" + + std::to_string(args.size()) + " given)"}; + ADT_LET_CONST_REF(self_ptr, self.Mut()); + self_ptr->push_back(args.at(0)); + return adt::Nothing{}; + } + + adt::Result Starred(const Self& self) { + 
ADT_LET_CONST_REF(vec, self.Get()); + adt::List ret{}; + ret->reserve(vec->size()); + ret->assign(vec->begin(), vec->end()); + return ap::axpr::Starred{ret}; + } +}; + +template +struct MethodClassImpl>> { + using Self = TypeImpl>; + using This = MethodClassImpl; + + adt::Result Call(const Self&) { return &This::StaticConstruct; } + + static adt::Result StaticConstruct( + axpr::InterpreterBase* interpreter, + const ValueT&, + const std::vector& args) { + return This{}.Construct(interpreter, args); + } + + adt::Result Construct(axpr::InterpreterBase* interpreter, + const std::vector& args) { + ADT_LET_CONST_REF(ref_lst, + adt::WeakPtrLock(interpreter->circlable_ref_list())); + const auto& mut_list = MutableList::Make( + ref_lst, std::make_shared>()); + ADT_LET_CONST_REF(ptr, mut_list.Mut()); + *ptr = args; + return mut_list; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/mutable_ordered_dict.h b/paddle/ap/include/axpr/mutable_ordered_dict.h new file mode 100644 index 00000000000000..5ec74f9548e2ce --- /dev/null +++ b/paddle/ap/include/axpr/mutable_ordered_dict.h @@ -0,0 +1,42 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/ordered_dict.h" +#include "paddle/ap/include/memory/circlable_ref.h" + +namespace ap::axpr { + +template +using MutableOrderedDictImpl = + OrderedDictImpl>; + +template +struct MutableOrderedDict + : public memory::CirclableRef, + MutableOrderedDictImpl> { + using Base = memory::CirclableRef, + MutableOrderedDictImpl>; + using Base::CirclableRef; +}; + +template +struct TypeImpl> : public std::monostate { + using value_type = MutableOrderedDict; + + const char* Name() const { return "MutableOrderedDict"; } +}; + +}; // namespace ap::axpr diff --git a/paddle/ap/include/axpr/mutable_ordered_dict_method_class.h b/paddle/ap/include/axpr/mutable_ordered_dict_method_class.h new file mode 100644 index 00000000000000..adc97e063f11dc --- /dev/null +++ b/paddle/ap/include/axpr/mutable_ordered_dict_method_class.h @@ -0,0 +1,184 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/mutable_ordered_dict.h" +#include "paddle/ap/include/axpr/to_string.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using Val = ValueT; + using This = MethodClassImpl>; + using Self = MutableOrderedDict; + + adt::Result Length(const Self& self) { + ADT_LET_CONST_REF(self_ptr, self.Get()); + return static_cast(self_ptr->items().size()); + } + + adt::Result ToString(axpr::InterpreterBase* interpreter, + const Self& self) { + ADT_LET_CONST_REF(self_ptr, self.Get()); + std::ostringstream ss; + ss << "MutableOrderedDict(["; + int i = 0; + for (const auto& [k, v] : self_ptr->items()) { + if (i++ > 0) { + ss << ", "; + } + ADT_LET_CONST_REF(key_str, axpr::ToString(interpreter, k)); + ADT_LET_CONST_REF(value_str, axpr::ToString(interpreter, v)); + ss << "[" << key_str << ", " << value_str << "]"; + } + ss << "])"; + return ss.str(); + } + + adt::Result Hash(axpr::InterpreterBase* interpreter, + const Self& self) { + ADT_LET_CONST_REF(self_ptr, self.Get()); + int64_t hash_value = 0; + for (const auto& [k, v] : self_ptr->items()) { + ADT_LET_CONST_REF(key_hash_value, axpr::Hash{}(interpreter, k)); + ADT_LET_CONST_REF(value_hash_value, axpr::Hash{}(interpreter, v)); + hash_value = adt::hash_combine(hash_value, key_hash_value); + hash_value = adt::hash_combine(hash_value, value_hash_value); + } + return hash_value; + } + + adt::Result GetItem(axpr::InterpreterBase* interpreter, + const Self& self, + const ValueT& key) { + ADT_LET_CONST_REF(self_ptr, self.Get()); + ADT_LET_CONST_REF(val, self_ptr->At(interpreter, key)) + << adt::errors::KeyError{axpr::ToDebugString(interpreter, key)}; + return val; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "items") { + return axpr::Method{ + self, 
&axpr::WrapAsBuiltinFuncType}; + } + if (attr_name == "contains") { + return axpr::Method{self, &This::Contains}; + } + if (attr_name == "get_or_create") { + return axpr::Method{self, &This::GetOrCreate}; + } + return adt::errors::TypeError{ + std::string() + "MutableOrderedDict object has no attribute '" + + attr_name + "'"}; + } + + static adt::Result GetOrCreate(InterpreterBase* interpreter, + const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(self_ptr, self.Mut()); + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + + "MutableOrderedDict.get_or_create() takes 2 argument, but " + + std::to_string(args.size()) + " were given"}; + const auto& key = args.at(0); + ADT_LET_CONST_REF(has_key, self_ptr->Has(interpreter, key)); + if (!has_key) { + ADT_LET_CONST_REF(val, interpreter->InterpretCall(args.at(1), {})); + ADT_RETURN_IF_ERR(self_ptr->Insert(interpreter, key, val)); + return val; + } else { + ADT_LET_CONST_REF(val, self_ptr->At(interpreter, key)) + << adt::errors::KeyError{axpr::ToDebugString(interpreter, key)}; + return val; + } + } + + static adt::Result Contains(InterpreterBase* interpreter, + const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(self_ptr, self.Get()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "MutableOrderedDict.contains() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(has_elt, self_ptr->Has(interpreter, args.at(0))); + return has_elt; + } + + adt::Result Items(const Self& self, const std::vector&) { + ADT_LET_CONST_REF(self_ptr, self.Get()); + adt::List lst; + lst->reserve(self_ptr->items().size()); + for (const auto& [first, second] : self_ptr->items()) { + lst->emplace_back(adt::List{first, second}); + } + return lst; + } +}; + +template +struct MethodClassImpl>> { + using Val = ValueT; + 
using Self = TypeImpl>; + using This = MethodClassImpl; + + adt::Result Call(const Self& self) { + return axpr::Method{self, &This::Construct}; + } + + static adt::Result Construct(InterpreterBase* interpreter, + const ValueT&, + const std::vector& args) { + auto impl = std::make_shared>(); + ADT_LET_CONST_REF(circlable_ref_list, + adt::WeakPtrLock(interpreter->circlable_ref_list())); + auto ordered_dict = + MutableOrderedDict::Make(circlable_ref_list, impl); + if (args.size() == 0) { + return ordered_dict; + } + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "MutableOrderedDict() takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lst, args.at(0).template TryGet>()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of MutableOrderedDict() should be list, " + + axpr::GetTypeName(args.at(0)) + " found."}; + int i = 0; + for (const auto& elt : *lst) { + ADT_LET_CONST_REF(pair, elt.template TryGet>()) + << adt::errors::TypeError{std::string() + "sequence item " + + std::to_string(i) + + " : expected list instance, " + + axpr::GetTypeName(elt) + " found."}; + ADT_CHECK(pair->size() == 2) << adt::errors::TypeError{ + std::string() + "sequence item " + std::to_string(i) + + " : expected 2-item list, " + std::to_string(pair->size()) + + "-item list found."}; + ADT_RETURN_IF_ERR(impl->Insert(interpreter, pair->at(0), pair->at(1))); + ++i; + } + return ordered_dict; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/naive_class_ops.h b/paddle/ap/include/axpr/naive_class_ops.h new file mode 100644 index 00000000000000..5a6f1b8728fe61 --- /dev/null +++ b/paddle/ap/include/axpr/naive_class_ops.h @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/class_ops.h" +#include "paddle/ap/include/axpr/value.h" + +namespace ap::axpr { + +template +class NaiveClassOps : public ClassOps { + public: + explicit NaiveClassOps(const ClassAttrs& class_attrs) + : class_attrs_(class_attrs) {} + + using This = NaiveClassOps; + + const ClassAttrsImpl* class_attrs() const override { + return class_attrs_.shared_ptr().get(); + } + + adt::Result Equals(const axpr::Value& lhs_val, + const axpr::Value& rhs_val) const override { + return EqualsImpl(lhs_val, rhs_val); + } + + private: + static adt::Result EqualsImpl(const axpr::Value& lhs_val, + const axpr::Value& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template CastTo()); + if (!rhs_val.template CastableTo()) { + return false; + } + ADT_LET_CONST_REF(rhs, rhs_val.template CastTo()); + return lhs == rhs; + } + + const ClassAttrs class_attrs_; +}; + +template +class ClassOps* MakeGlobalNaiveClassOps( + const ClassAttrs& class_attrs) { + static NaiveClassOps naive_class_ops(class_attrs); + return &naive_class_ops; +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/nothing.h b/paddle/ap/include/axpr/nothing.h new file mode 100644 index 00000000000000..dff32e0943f0b2 --- /dev/null +++ b/paddle/ap/include/axpr/nothing.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using value_type = adt::Nothing; + + const char* Name() const { return "NoneType"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/nothing_method_class.h b/paddle/ap/include/axpr/nothing_method_class.h new file mode 100644 index 00000000000000..0a402ec6ad411b --- /dev/null +++ b/paddle/ap/include/axpr/nothing_method_class.h @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct NothingMethodClass { + using This = NothingMethodClass; + using Self = adt::Nothing; + + adt::Result ToString(const Self&) { return std::string(""); } + + adt::Result Hash(const Self& self) { return static_cast(0); } + + Result EQ(const ValueT& lhs_val, const ValueT& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + return rhs_val.Match([](adt::Nothing) -> ValueT { return true; }, + [](const auto&) -> ValueT { return false; }); + } + + Result NE(const ValueT& lhs_val, const ValueT& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + return rhs_val.Match([](adt::Nothing) -> ValueT { return false; }, + [](const auto&) -> ValueT { return true; }); + } +}; + +template +struct MethodClassImpl + : public NothingMethodClass {}; + +template +struct MethodClassImpl> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/ordered_dict.h b/paddle/ap/include/axpr/ordered_dict.h new file mode 100644 index 00000000000000..451ca2cdf6ac0e --- /dev/null +++ b/paddle/ap/include/axpr/ordered_dict.h @@ -0,0 +1,107 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/hash.h" +#include "paddle/ap/include/axpr/interpreter_base.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct OrderedDictImpl { + public: + OrderedDictImpl() {} + + bool operator==(const OrderedDictImpl& other) const { return this == &other; } + + using ItemT = std::pair; + + const std::list& items() const { return items_; } + + adt::Result Has(InterpreterBase* interpreter, + const KeyT& key) const { + Hasher hasher{}; + ADT_LET_CONST_REF(hash_value, hasher(interpreter, key)); + const auto& iter_to_iters = this->hash_value2pair_iters_.find(hash_value); + if (iter_to_iters == this->hash_value2pair_iters_.end()) { + return false; + } + for (auto iter : iter_to_iters->second) { + if (iter->first == key) { + return true; + } + } + return false; + } + + adt::Result At(InterpreterBase* interpreter, + const KeyT& key) const { + Hasher hasher{}; + ADT_LET_CONST_REF(hash_value, hasher(interpreter, key)); + const auto& iter_to_iters = this->hash_value2pair_iters_.find(hash_value); + ADT_CHECK(iter_to_iters != this->hash_value2pair_iters_.end()); + for (auto iter : iter_to_iters->second) { + if (iter->first == key) { + return iter->second; + } + } + return adt::errors::KeyError{"OrderedDictImpl::At() failed."}; + } + + adt::Result Insert(InterpreterBase* interpreter, + const ItemT& pair) { + return Insert(interpreter, pair.first, pair.second); + } + + adt::Result Insert(InterpreterBase* interpreter, + const ValueT& key, + const ValueT& val) { + Hasher hasher{}; + ADT_LET_CONST_REF(hash_value, hasher(interpreter, key)); + auto* lst = &this->hash_value2pair_iters_[hash_value]; + for (auto iter : *lst) { + if (iter->first == key) { + iter->second = val; + return adt::Ok{}; + } + } + lst->emplace_back( + this->items_.insert(this->items_.end(), std::pair{key, val})); + return adt::Ok{}; + } + + private: + using ItemsT = 
std::list; + ItemsT items_; + std::unordered_map> + hash_value2pair_iters_; +}; + +template +ADT_DEFINE_RC(OrderedDict, OrderedDictImpl>); + +template +struct TypeImpl> : public std::monostate { + using std::monostate::monostate; + + const char* Name() const { return "OrderedDict"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/ordered_dict_method_class.h b/paddle/ap/include/axpr/ordered_dict_method_class.h new file mode 100644 index 00000000000000..429315731480a6 --- /dev/null +++ b/paddle/ap/include/axpr/ordered_dict_method_class.h @@ -0,0 +1,152 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/ordered_dict.h" +#include "paddle/ap/include/axpr/to_string.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using Val = ValueT; + using This = MethodClassImpl>; + using Self = OrderedDict; + + adt::Result Length(const Self& self) { + return static_cast(self->items().size()); + } + + adt::Result ToString(axpr::InterpreterBase* interpreter, + const Self& self) { + std::ostringstream ss; + ss << "OrderedDict(["; + int i = 0; + for (const auto& [k, v] : self->items()) { + if (i++ > 0) { + ss << ", "; + } + ADT_LET_CONST_REF(key_str, axpr::ToString(interpreter, k)); + ADT_LET_CONST_REF(value_str, axpr::ToString(interpreter, v)); + ss << "[" << key_str << ", " << value_str << "]"; + } + ss << "])"; + return ss.str(); + } + + adt::Result Hash(axpr::InterpreterBase* interpreter, + const Self& self) { + int64_t hash_value = 0; + for (const auto& [k, v] : self->items()) { + ADT_LET_CONST_REF(key_hash_value, axpr::Hash{}(interpreter, k)); + ADT_LET_CONST_REF(value_hash_value, axpr::Hash{}(interpreter, v)); + hash_value = adt::hash_combine(hash_value, key_hash_value); + hash_value = adt::hash_combine(hash_value, value_hash_value); + } + return hash_value; + } + + adt::Result GetItem(axpr::InterpreterBase* interpreter, + const Self& self, + const ValueT& key) { + ADT_LET_CONST_REF(val, self->At(interpreter, key)) + << adt::errors::KeyError{axpr::ToDebugString(interpreter, key)}; + return val; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "items") { + return axpr::Method{ + self, &axpr::WrapAsBuiltinFuncType}; + } + if (attr_name == "contains") { + return axpr::Method{self, &This::Contains}; + } + return adt::errors::TypeError{std::string() + + "OrderedDict object has no attribute '" + + 
attr_name + "'"}; + } + + static adt::Result Contains( + axpr::InterpreterBase* interpreter, + const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "OrderedDict.contains() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(has_elt, self->Has(interpreter, args.at(0))); + return has_elt; + } + + adt::Result Items(const Self& self, const std::vector&) { + adt::List lst; + lst->reserve(self->items().size()); + for (const auto& [first, second] : self->items()) { + lst->emplace_back(adt::List{first, second}); + } + return lst; + } +}; + +template +struct MethodClassImpl>> { + using Val = ValueT; + using Self = TypeImpl>; + using This = MethodClassImpl; + + adt::Result Call(const Self& self) { + return axpr::Method{self, &This::Construct}; + } + + static adt::Result Construct( + axpr::InterpreterBase* interpreter, + const ValueT&, + const std::vector& args) { + if (args.size() == 0) { + return OrderedDict{}; + } + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "OrderedDict() takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lst, args.at(0).template TryGet>()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of OrderedDict() should be list, " + + axpr::GetTypeName(args.at(0)) + " found."}; + OrderedDict ordered_dict{}; + int i = 0; + for (const auto& elt : *lst) { + ADT_LET_CONST_REF(pair, elt.template TryGet>()) + << adt::errors::TypeError{std::string() + "sequence item " + + std::to_string(i) + + " : expected list instance, " + + axpr::GetTypeName(elt) + " found."}; + ADT_CHECK(pair->size() == 2) << adt::errors::TypeError{ + std::string() + "sequence item " + std::to_string(i) + + " : expected 2-item list, " + std::to_string(pair->size()) + + "-item list found."}; + ADT_RETURN_IF_ERR( + 
ordered_dict->Insert(interpreter, pair->at(0), pair->at(1))); + ++i; + } + return ordered_dict; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/packed_args.h b/paddle/ap/include/axpr/packed_args.h new file mode 100644 index 00000000000000..39f75ce12aa636 --- /dev/null +++ b/paddle/ap/include/axpr/packed_args.h @@ -0,0 +1,57 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct PackedArgsImpl { + adt::List args; + axpr::AttrMap kwargs; + + bool operator==(const PackedArgsImpl& other) const { + return this->args == other.args && this->kwargs == other.kwargs; + } +}; + +template +ADT_DEFINE_RC(PackedArgs, PackedArgsImpl); + +template +PackedArgs CastToPackedArgs( + const std::vector& packed_args_vec) { + if (packed_args_vec.size() == 1 && + packed_args_vec.at(0).template Has>()) { + return packed_args_vec.at(0).template Get>(); + } else { + adt::List pos_args{}; + pos_args->reserve(packed_args_vec.size()); + pos_args->assign(packed_args_vec.begin(), packed_args_vec.end()); + return PackedArgs{pos_args, AttrMap{}}; + } +} + +template +struct TypeImpl> : public std::monostate { + using value_type = PackedArgs; + + const char* Name() const { return "__builtin_PackedArgs__"; } +}; + +} // 
namespace ap::axpr diff --git a/paddle/ap/include/axpr/packed_args_method_class.h b/paddle/ap/include/axpr/packed_args_method_class.h new file mode 100644 index 00000000000000..a636b1ef1fbcf8 --- /dev/null +++ b/paddle/ap/include/axpr/packed_args_method_class.h @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/packed_args.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> {}; + +template +struct MethodClassImpl>> { + using This = MethodClassImpl>>; + using Self = TypeImpl>; + adt::Result Call(const Self& self) { return &This::Construct; } + + static adt::Result Construct(const ValueT&, + const std::vector& args) { + return This{}.Make(args); + } + + adt::Result Make(const std::vector& args) { + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(positional_args, + TryGetImpl>(args.at(0))); + ADT_LET_CONST_REF(keyword_args_val, + TryGetImpl>(args.at(1))); + axpr::AttrMap keyword_args; + for (const auto& pair_val : *keyword_args_val) { + ADT_LET_CONST_REF(pair, TryGetImpl>(pair_val)); + ADT_CHECK(pair->size() == 2); + ADT_LET_CONST_REF(key, TryGetImpl(pair->at(0))); + keyword_args->Set(key, pair->at(1)); + } + return PackedArgs{positional_args, keyword_args}; + } +}; + +} // namespace ap::axpr diff --git 
a/paddle/ap/include/axpr/pointer_type.h b/paddle/ap/include/axpr/pointer_type.h new file mode 100644 index 00000000000000..2492e01c95e67f --- /dev/null +++ b/paddle/ap/include/axpr/pointer_type.h @@ -0,0 +1,81 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct GetPointerTypeNameHelper; + +#define SPECIALIZE_GET_CPP_TYPE_NAME(cpp_type, enum_type) \ + template <> \ + struct GetPointerTypeNameHelper { \ + static const char* Name() { return #cpp_type "_ptr"; } \ + }; \ + template <> \ + struct GetPointerTypeNameHelper { \ + static const char* Name() { return "const_" #cpp_type "_ptr"; } \ + }; +PD_FOR_EACH_DATA_TYPE(SPECIALIZE_GET_CPP_TYPE_NAME); +#undef SPECIALIZE_GET_CPP_TYPE_NAME + +template <> +struct GetPointerTypeNameHelper { + static const char* Name() { return "void_ptr"; } +}; + +template <> +struct GetPointerTypeNameHelper { + static const char* Name() { return "const_void_ptr"; } +}; + +template +struct CppPointerType : public std::monostate { + using std::monostate::monostate; + using type = T; + const char* Name() const { return GetPointerTypeNameHelper::Name(); } +}; + +// clang-format off +using PointerTypeImpl = std::variant< +#define MAKE_POINTER_TYPE_ALTERNATIVE(cpp_type, enum_type) \ + 
CppPointerType, \ + CppPointerType, + PD_FOR_EACH_DATA_TYPE(MAKE_POINTER_TYPE_ALTERNATIVE) +#undef MAKE_POINTER_TYPE_ALTERNATIVE + CppPointerType, + CppPointerType>; +// clang-format on + +struct PointerType : public PointerTypeImpl { + using PointerTypeImpl::PointerTypeImpl; + ADT_DEFINE_VARIANT_METHODS(PointerTypeImpl); + + const char* Name() const { + return Match([](const auto& impl) { return impl.Name(); }); + } +}; + +template <> +struct TypeImpl : public std::monostate { + using value_type = PointerType; + + const char* Name() const { return "PointerType"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/pointer_type_method_class.h b/paddle/ap/include/axpr/pointer_type_method_class.h new file mode 100644 index 00000000000000..afc241c3da5792 --- /dev/null +++ b/paddle/ap/include/axpr/pointer_type_method_class.h @@ -0,0 +1,136 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/int_data_type.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/pointer_type.h" +#include "paddle/ap/include/axpr/pointer_type_util.h" + +namespace ap::axpr { + +template +struct PointerTypeMethodClass { + using This = PointerTypeMethodClass; + using Self = PointerType; + + adt::Result ToString(const Self& self) { + return std::string("PointerType.") + self.Name(); + } + + adt::Result Hash(const Self& self) { + int64_t hash_value = std::hash()("PointerType"); + hash_value = adt::hash_combine(hash_value, self.index()); + return hash_value; + } + + adt::Result GetAttr(const Self& self, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template CastTo()); + if (attr_name == "data_type") { + return GetDataType(self); + } + return adt::errors::AttributeError{ + std::string() + "PointerType has no attribute '" + attr_name + "'"}; + } + + adt::Result GetDataType(const Self& self) { + return GetDataTypeTypeFromPointerType(self); + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (std::is_same_v) { + return &This::EQ; + } else if constexpr (std::is_same_v) { + return &This::NE; + } else { + return adt::Nothing{}; + } + } + + static Result EQ(const ValueT& lhs_val, const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template TryGet()); + ADT_LET_CONST_REF(rhs, rhs_val.template TryGet()); + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> ValueT { + return std::is_same_v; + }}; + return std::visit(pattern_match, lhs.variant(), rhs.variant()); + } + + static Result NE(const ValueT& lhs_val, const ValueT& rhs_val) { + ADT_LET_CONST_REF(lhs, lhs_val.template TryGet()); + ADT_LET_CONST_REF(rhs, rhs_val.template TryGet()); + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> ValueT { 
+ return !std::is_same_v; + }}; + return std::visit(pattern_match, lhs.variant(), rhs.variant()); + } +}; + +template +struct MethodClassImpl + : public PointerTypeMethodClass {}; + +template +struct TypeImplPointerTypeMethodClass { + using This = TypeImplPointerTypeMethodClass; + using Self = TypeImpl; + + template + const char* PtrTypeName() { + return axpr::CppPointerType{}.Name(); + } + + template + PointerType PtrType() { + return PointerType{CppPointerType{}}; + } + + adt::Result GetAttr(const Self&, const ValueT& attr_name_val) { + ADT_LET_CONST_REF(attr_name, TryGetImpl(attr_name_val)); + static const std::unordered_map map{ +#define MAKE_CPP_TYPE_CASE(cpp_type, enum_type) \ + {PtrTypeName(), PtrType()}, \ + {PtrTypeName(), PtrType()}, + PD_FOR_EACH_DATA_TYPE(MAKE_CPP_TYPE_CASE) +#undef MAKE_CPP_TYPE_CASE +#define MAKE_INT_CPP_TYPE_CASE(cpp_type) \ + {#cpp_type "_ptr", PtrType()}, \ + {"const_" #cpp_type "_ptr", PtrType()}, + AP_FOR_EACH_INT_TYPE(MAKE_INT_CPP_TYPE_CASE) +#undef MAKE_INT_CPP_TYPE_CASE + {PtrTypeName(), PtrType()}, + {PtrTypeName(), PtrType()}, + }; + const auto iter = map.find(attr_name); + if (iter != map.end()) { + return iter->second; + } + return adt::errors::AttributeError{ + std::string() + "class 'PointerType' has no static attribute '" + + attr_name + "'."}; + } +}; + +template +struct MethodClassImpl> + : public TypeImplPointerTypeMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/pointer_type_util.h b/paddle/ap/include/axpr/pointer_type_util.h new file mode 100644 index 00000000000000..ca212beee52ab8 --- /dev/null +++ b/paddle/ap/include/axpr/pointer_type_util.h @@ -0,0 +1,102 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/pointer_type.h" + +namespace ap::axpr { + +PointerType RemoveConst(const PointerType& ptr_type); +PointerType GetConstPointerTypeFromDataType(const DataType& data_type); +PointerType GetMutablePointerTypeFromDataType(const DataType& data_type); +DataType GetDataTypeTypeFromPointerType(const PointerType& pointer_type); + +namespace detail { + +template +struct TypeConverter; + +#define SPECIALIZE_TYPE_CONVERTER(cpp_type, enum_type) \ + template <> \ + struct TypeConverter> { \ + using remove_const_type = CppPointerType; \ + }; \ + template <> \ + struct TypeConverter> { \ + using remove_const_type = CppPointerType; \ + }; + +PD_FOR_EACH_DATA_TYPE(SPECIALIZE_TYPE_CONVERTER); +#undef SPECIALIZE_TYPE_CONVERTER + +template <> +struct TypeConverter> { + using remove_const_type = CppPointerType; +}; + +template <> +struct TypeConverter> { + using remove_const_type = CppPointerType; +}; + +} // namespace detail + +inline PointerType RemoveConst(const PointerType& ptr_type) { + return ptr_type.Match([](auto impl) { + return PointerType{ + typename detail::TypeConverter::remove_const_type{}}; + }); +} + +inline PointerType GetConstPointerTypeFromDataType(const DataType& data_type) { + return data_type.Match([&](const auto& impl) -> PointerType { + using T = typename std::decay_t::type; + if constexpr (std::is_same_v) { + return CppPointerType{}; + } else { + return CppPointerType{}; + } + }); +} + +inline PointerType GetMutablePointerTypeFromDataType( + const 
DataType& data_type) { + return data_type.Match([&](const auto& impl) -> PointerType { + using T = typename std::decay_t::type; + if constexpr (std::is_same_v) { + return CppPointerType{}; + } else { + return CppPointerType{}; + } + }); +} + +inline DataType GetDataTypeTypeFromPointerType( + const PointerType& pointer_type) { + return pointer_type.Match( + [&](CppPointerType) -> DataType { + return CppDataType{}; + }, + [&](CppPointerType) -> DataType { + return CppDataType{}; + }, + [&](const auto& impl) -> DataType { + using PtrT = typename std::decay_t::type; + using T = std::remove_const_t>; + return CppDataType{}; + }); +} +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/pointer_value.h b/paddle/ap/include/axpr/pointer_value.h new file mode 100644 index 00000000000000..31dc40d0b2df17 --- /dev/null +++ b/paddle/ap/include/axpr/pointer_value.h @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/pointer_type.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +using PointerValueImpl = std::variant< +#define MAKE_ARG_VALUE_ALTERNATIVE(cpp_type, enum_type) \ + cpp_type*, const cpp_type*, + PD_FOR_EACH_DATA_TYPE(MAKE_ARG_VALUE_ALTERNATIVE) void*, + const void* +#undef MAKE_ARG_VALUE_ALTERNATIVE + >; + +struct PointerValue : public PointerValueImpl { + using PointerValueImpl::PointerValueImpl; + ADT_DEFINE_VARIANT_METHODS(PointerValueImpl); + + PointerType GetType() const { + return Match([](auto impl) -> PointerType { + return PointerType{CppPointerType{}}; + }); + } +}; + +template <> +struct TypeImpl : public std::monostate { + using value_type = PointerValue; + + const char* Name() const { return "PointerValue"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/pointer_value_method_class.h b/paddle/ap/include/axpr/pointer_value_method_class.h new file mode 100644 index 00000000000000..7328f37da87953 --- /dev/null +++ b/paddle/ap/include/axpr/pointer_value_method_class.h @@ -0,0 +1,104 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/pointer_value.h" + +namespace ap::axpr { + +template +struct PointerValueMethodClass { + using This = PointerValueMethodClass; + using Self = PointerValue; + + adt::Result ToString(const Self& self) { + return self.Match([](const auto* impl) -> std::string { + std::ostringstream ss; + ss << impl; + return ss.str(); + }); + } + + adt::Result Hash(const Self& self) { + return self.Match([](const auto* impl) -> int64_t { + return reinterpret_cast(impl); + }); + } + + template + static BuiltinUnaryFunc GetBuiltinUnaryFunc() { + return adt::Nothing{}; + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (std::is_same_v) { + return &This::EQ; + } else if constexpr (std::is_same_v) { + return &This::NE; + } else { + return adt::Nothing{}; + } + } + + static Result EQ(const ValueT& lhs_val, const ValueT& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + const auto& lhs = opt_lhs.GetOkValue(); + const auto& opt_rhs = rhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_rhs); + const auto& rhs = opt_rhs.GetOkValue(); + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> ValueT { + if constexpr (std::is_same_v) { + return lhs == rhs; + } else { + return false; + } + }}; + return std::visit(pattern_match, lhs.variant(), rhs.variant()); + } + + static Result NE(const ValueT& lhs_val, const ValueT& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + const auto& lhs = opt_lhs.GetOkValue(); + const auto& opt_rhs = rhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_rhs); + const auto& rhs = opt_rhs.GetOkValue(); + const auto& pattern_match = + ::common::Overloaded{[](auto lhs, auto rhs) -> ValueT { + if constexpr (std::is_same_v) { + return lhs != rhs; + } else { + return true; + } + }}; + return 
std::visit(pattern_match, lhs.variant(), rhs.variant()); + } +}; + +template +struct MethodClassImpl + : public PointerValueMethodClass {}; + +template +struct MethodClassImpl> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/s_expr.h b/paddle/ap/include/axpr/s_expr.h new file mode 100644 index 00000000000000..39dc375f73a496 --- /dev/null +++ b/paddle/ap/include/axpr/s_expr.h @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/atomic.h" + +namespace ap::axpr { + +struct SExpr; + +// (outer_func (inner_func [args])) +template +struct SListImpl { + std::vector children; + + bool operator==(const SListImpl& other) const { + return this->children == other.children; + } +}; + +template +ADT_DEFINE_RC(SList, const SListImpl); + +// s expression +// expr := aexpr | ([expr]) +using SExprBase = std::variant, SList>; + +struct SExpr : public SExprBase { + using SExprBase::SExprBase; + ADT_DEFINE_VARIANT_METHODS(SExprBase); + + std::string ToSExpression() const; +}; + +} // namespace ap::axpr + +namespace std { + +inline std::ostream& operator<<(std::ostream& os, + const ap::axpr::SExpr& core_expr) { + return os << core_expr.ToSExpression(); +} + +} // namespace std diff --git a/paddle/ap/include/axpr/serializable_list.h b/paddle/ap/include/axpr/serializable_list.h new file mode 100644 index 00000000000000..10c26cb3fe581c --- /dev/null +++ b/paddle/ap/include/axpr/serializable_list.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl> : public std::monostate { + using value_type = adt::List; + + const char* Name() const { return "SerializableList"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/serializable_list_method_class.h b/paddle/ap/include/axpr/serializable_list_method_class.h new file mode 100644 index 00000000000000..a11996586264ea --- /dev/null +++ b/paddle/ap/include/axpr/serializable_list_method_class.h @@ -0,0 +1,73 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/serializable_list.h" +#include "paddle/ap/include/axpr/starred.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> { + using This = MethodClassImpl>; + using Self = adt::List; + + adt::Result Length(const Self& self) { + return static_cast(self->size()); + } + + adt::Result ToString(const Self& self) { + ADT_LET_CONST_REF(str, SerializableValueHelper{}.ToString(self)); + return str; + } + + adt::Result Hash(const Self& self) { + ADT_LET_CONST_REF(hash_value, SerializableValueHelper{}.Hash(self)); + return hash_value; + } + + adt::Result GetItem(const Self& self, const ValueT& idx) { + return idx.Match( + [&](int64_t index) -> Result { + if (index < 0) { + index += self->size(); + } + if (index >= 0 && index < self->size()) { + return self->at(index).template CastTo(); + } + return adt::errors::IndexError{"list index out of range"}; + }, + [&](const auto&) -> Result { + return adt::errors::TypeError{std::string() + + "list indices must be integers, not " + + axpr::GetTypeName(idx)}; + }); + } + + adt::Result Starred(const Self& self) { + return ap::axpr::Starred{self}; + } +}; + +template +struct MethodClassImpl>> { + using Self = TypeImpl>; + + using This = MethodClassImpl; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/serializable_value.h b/paddle/ap/include/axpr/serializable_value.h new file mode 100644 index 00000000000000..7c622960e0ef13 --- /dev/null +++ b/paddle/ap/include/axpr/serializable_value.h @@ -0,0 +1,147 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/bool.h" +#include "paddle/ap/include/axpr/builtin_func_name_mgr.h" +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/class_attrs.h" +#include "paddle/ap/include/axpr/float.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/int.h" +#include "paddle/ap/include/axpr/nothing.h" +#include "paddle/ap/include/axpr/string.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +struct BuiltinFuncVoidPtr { + void* func_ptr; + + bool operator==(const BuiltinFuncVoidPtr& other) const { + return this->func_ptr == other.func_ptr; + } +}; + +struct BuiltinHighOrderFuncVoidPtr { + void* func_ptr; + + bool operator==(const BuiltinHighOrderFuncVoidPtr& other) const { + return this->func_ptr == other.func_ptr; + } +}; + +template +using SerializableValueImpl = std::variant, + TypeImpl, + TypeImpl, + TypeImpl, + TypeImpl, + ClassAttrs, + adt::Nothing, + bool, + int64_t, + double, + std::string, + Function, + adt::List, + AttrMap, + BuiltinFuncVoidPtr, + BuiltinHighOrderFuncVoidPtr>; + +template +struct ClassInstance; + +struct SerializableValue : public SerializableValueImpl { + using SerializableValueImpl::SerializableValueImpl; + + ADT_DEFINE_VARIANT_METHODS(SerializableValueImpl); + + template + ValueT CastTo() const { + return Match( + [&](const BuiltinFuncVoidPtr& func) -> ValueT { + return reinterpret_cast>(func.func_ptr); + }, + [&](const 
BuiltinHighOrderFuncVoidPtr& func) -> ValueT { + return reinterpret_cast>( + func.func_ptr); + }, + [&](const ClassAttrs& class_attrs) -> ValueT { + return TypeImpl>(class_attrs); + }, + [&](const auto& impl) -> ValueT { return impl; }); + } + + template + static bool IsSerializable(const ValueT& val) { + using TypeT = typename TypeTrait::TypeT; + return val.Match( + [&](const TypeT& type) -> bool { + return type.Match( + [](const TypeImpl&) -> bool { return true; }, + [](const TypeImpl&) -> bool { return true; }, + [](const TypeImpl&) -> bool { return true; }, + [](const TypeImpl&) -> bool { return true; }, + [](const TypeImpl&) -> bool { return true; }, + [](const TypeImpl>&) -> bool { + return true; + }, + [&](const auto&) -> bool { return false; }); + }, + [](const Nothing&) -> bool { return true; }, + [](bool) -> bool { return true; }, + [](int64_t) -> bool { return true; }, + [](double) -> bool { return true; }, + [](const std::string&) -> bool { return true; }, + [](const Function&) -> bool { return true; }, + [](const adt::List&) -> bool { return true; }, + [](const AttrMap&) -> bool { return true; }, + [&](const adt::List& list) -> bool { + for (const auto& elt : *list) { + if (!IsSerializable(elt)) { + return false; + } + } + return true; + }, + [&](const AttrMap& object) -> bool { + for (const auto& [k, v] : object->object->storage) { + if (!IsSerializable(v)) { + return false; + } + } + return true; + }, + [&](const BuiltinFuncType& func) -> bool { + void* func_ptr = reinterpret_cast(func); + return BuiltinFuncNameMgr::Singleton()->Has(func_ptr); + }, + [&](const BuiltinHighOrderFuncType& func) -> bool { + void* func_ptr = reinterpret_cast(func); + return BuiltinFuncNameMgr::Singleton()->Has(func_ptr); + }, + [&](const auto&) -> bool { return false; }); + } + + static std::string SerializableTypeNames() { + return "NoneType, bool, int, float, str, class, function, " + "BuiltinSerializableList, BuiltinSerializableAttrMap"; + } +}; + +} // namespace 
ap::axpr diff --git a/paddle/ap/include/axpr/serializable_value_helper.h b/paddle/ap/include/axpr/serializable_value_helper.h new file mode 100644 index 00000000000000..6a136b57fdec55 --- /dev/null +++ b/paddle/ap/include/axpr/serializable_value_helper.h @@ -0,0 +1,269 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/anf_expr_helper.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::axpr { + +struct SerializableValueHelper { + template + adt::Result CastFrom(const ValueT& val) { + using RetT = adt::Result; + using TypeT = typename TypeTrait::TypeT; + return val.Match( + [&](const TypeT& type) -> RetT { + return type.Match( + [](const TypeImpl& impl) -> RetT { return impl; }, + [](const TypeImpl& impl) -> RetT { return impl; }, + [](const TypeImpl& impl) -> RetT { return impl; }, + [](const TypeImpl& impl) -> RetT { return impl; }, + [](const TypeImpl& impl) -> RetT { return impl; }, + [](const TypeImpl>& impl) -> RetT { + return impl.class_attrs; + }, + [&](const auto&) -> RetT { + std::ostringstream ss; + ss << "Builtin serializable types are: "; + ss << SerializableValue::SerializableTypeNames(); + ss << " (not include '" << axpr::GetTypeName(val) << "')."; + return 
adt::errors::ValueError{ss.str()}; + }); + }, + [](const Nothing& impl) -> RetT { return impl; }, + [](bool impl) -> RetT { return impl; }, + [](int64_t impl) -> RetT { return impl; }, + [](double impl) -> RetT { return impl; }, + [](const std::string& impl) -> RetT { return impl; }, + [](const Function& impl) -> RetT { return impl; }, + [](const adt::List& impl) -> RetT { return impl; }, + [](const AttrMap& impl) -> RetT { return impl; }, + [&](const adt::List& list) -> RetT { + return CastListFrom(list); + }, + [&](const AttrMap& object) -> RetT { + return CastObjectFrom(object); + }, + [&](const BuiltinFuncType& func) -> RetT { + auto* func_ptr = reinterpret_cast(func); + ADT_CHECK(BuiltinFuncNameMgr::Singleton()->Has(func_ptr)); + return BuiltinFuncVoidPtr{func_ptr}; + }, + [&](const BuiltinHighOrderFuncType& func) -> RetT { + auto* func_ptr = reinterpret_cast(func); + ADT_CHECK(BuiltinFuncNameMgr::Singleton()->Has(func_ptr)); + return BuiltinHighOrderFuncVoidPtr{func_ptr}; + }, + [&](const auto&) -> RetT { + std::ostringstream ss; + ss << "Builtin serializable types are: "; + ss << SerializableValue::SerializableTypeNames(); + ss << " (not include '" << axpr::GetTypeName(val) << "')."; + return adt::errors::ValueError{ss.str()}; + }); + } + + adt::Result Hash(const SerializableValue& val) { + using RetT = adt::Result; + return val.Match( + [](const TypeImpl&) -> RetT { + int64_t hash_value = + std::hash()(typeid(TypeImpl).name()); + return hash_value; + }, + [](const TypeImpl&) -> RetT { + int64_t hash_value = + std::hash()(typeid(TypeImpl).name()); + return hash_value; + }, + [](const TypeImpl&) -> RetT { + int64_t hash_value = + std::hash()(typeid(TypeImpl).name()); + return hash_value; + }, + [](const TypeImpl&) -> RetT { + int64_t hash_value = + std::hash()(typeid(TypeImpl).name()); + return hash_value; + }, + [](const TypeImpl&) -> RetT { + int64_t hash_value = + std::hash()(typeid(TypeImpl).name()); + return hash_value; + }, + [](const ClassAttrs& 
class_attrs) -> RetT {
+ return reinterpret_cast(class_attrs.shared_ptr().get());
+ },
+ [](const adt::Nothing&) -> RetT { return static_cast(0); },
+ [](bool c) -> RetT { return static_cast(c); },
+ [](int64_t c) -> RetT { return c; },
+ [](double c) -> RetT {
+ return static_cast(std::hash()(c));
+ },
+ [](const std::string& c) -> RetT {
+ return static_cast(std::hash()(c));
+ },
+ [](const Function& impl) -> RetT {
+ return impl->GetHashValue();
+ },
+ [&](const adt::List& lst) -> RetT {
+ return HashImpl(lst);
+ },
+ [&](const axpr::AttrMap& obj) -> RetT {
+ return HashImpl(obj);
+ },
+ [&](const BuiltinFuncVoidPtr& func) -> RetT {
+ return reinterpret_cast(func.func_ptr);
+ },
+ [&](const BuiltinHighOrderFuncVoidPtr& func) -> RetT {
+ return reinterpret_cast(func.func_ptr);
+ });
+ }
+
+ adt::Result HashImpl(const adt::List& lst) {
+ int64_t hash_value = 0;
+ for (const auto& elt : *lst) {
+ ADT_LET_CONST_REF(elt_hash, Hash(elt));
+ hash_value = adt::hash_combine(hash_value, elt_hash);
+ }
+ return hash_value;
+ }
+
+ adt::Result HashImpl(
+ const axpr::AttrMap& object) {
+ return reinterpret_cast(object.shared_ptr().get());
+ }
+
+ adt::Result ToString(const SerializableValue& val) {
+ using RetT = adt::Result;
+ return val.Match(
+ [](const TypeImpl&) -> RetT {
+ return TypeImpl{}.Name();
+ },
+ [](const TypeImpl&) -> RetT { return TypeImpl{}.Name(); },
+ [](const TypeImpl&) -> RetT {
+ return TypeImpl{}.Name();
+ },
+ [](const TypeImpl&) -> RetT {
+ return TypeImpl{}.Name();
+ },
+ [](const TypeImpl&) -> RetT {
+ return TypeImpl{}.Name();
+ },
+ [](const ClassAttrs& class_attrs) -> RetT {
+ return std::string() + "<class '" + class_attrs->class_name + "'>";
+ },
+ [](const adt::Nothing&) -> RetT { return "None"; },
+ [](bool c) -> RetT { return std::string(c ?
"True" : "False"); }, + [](int64_t c) -> RetT { return std::to_string(c); }, + [](double c) -> RetT { return std::to_string(c); }, + [](const std::string& c) -> RetT { + std::ostringstream ss; + ss << std::quoted(c); + return ss.str(); + }, + [](const Function& impl) -> RetT { + const auto& lambda = impl->lambda; + const auto& anf_expr = ConvertCoreExprToAnfExpr(lambda); + ADT_LET_CONST_REF(anf_atomic, + anf_expr.template TryGet>()); + ADT_LET_CONST_REF(anf_lambda, + anf_atomic.template TryGet>()); + AnfExprHelper anf_expr_helper; + ADT_LET_CONST_REF(anf_expr_str, + anf_expr_helper.FunctionToString(anf_lambda)); + return anf_expr_str; + }, + [&](const adt::List& lst) -> RetT { + return ToStringImpl(lst); + }, + [&](const axpr::AttrMap& obj) -> RetT { + return ToStringImpl(obj); + }, + [&](const BuiltinFuncVoidPtr& func) -> RetT { + const auto& name_info = + BuiltinFuncNameMgr::Singleton()->OptGet(func.func_ptr); + ADT_CHECK(name_info.has_value()); + return name_info.value().ToString(); + }, + [&](const BuiltinHighOrderFuncVoidPtr& func) -> RetT { + const auto& name_info = + BuiltinFuncNameMgr::Singleton()->OptGet(func.func_ptr); + ADT_CHECK(name_info.has_value()); + return name_info.value().ToString(); + }); + } + + adt::Result ToStringImpl( + const adt::List& lst) { + std::ostringstream ss; + ss << "["; + int i = 0; + for (const auto& elt : *lst) { + if (i++ > 0) { + ss << ", "; + } + ADT_LET_CONST_REF(str, ToString(elt)); + ss << str; + } + ss << "]"; + return ss.str(); + } + + adt::Result ToStringImpl( + const axpr::AttrMap& object) { + std::ostringstream ss; + ss << "{"; + int i = 0; + for (const auto& [k, v] : object->storage) { + if (i++ > 0) { + ss << ", "; + } + ss << std::quoted(k); + ss << ":"; + ADT_LET_CONST_REF(str, ToString(v)); + ss << str; + } + ss << "}"; + return ss.str(); + } + + template + adt::Result CastListFrom(const adt::List& lst) { + adt::List ret; + ret->reserve(lst->size()); + for (const auto& elt : *lst) { + ADT_LET_CONST_REF(converted, 
CastFrom(elt)); + ret->emplace_back(converted); + } + return ret; + } + + template + adt::Result CastObjectFrom(const AttrMap& obj) { + AttrMap ret_object{}; + for (const auto& [k, v] : obj->storage) { + ADT_LET_CONST_REF(converted, CastFrom(v)); + ret_object->Set(k, converted); + } + return AttrMap{ret_object}; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/starred.h b/paddle/ap/include/axpr/starred.h new file mode 100644 index 00000000000000..9be9ece3d076fe --- /dev/null +++ b/paddle/ap/include/axpr/starred.h @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct StarredImpl { + ValueT obj; + + bool operator==(const StarredImpl& other) const { + return other.obj == this->obj; + } +}; + +template +ADT_DEFINE_RC(Starred, const StarredImpl); + +template +struct TypeImpl> : public std::monostate { + using value_type = Starred; + + const char* Name() const { return "starred"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/starred_method_class.h b/paddle/ap/include/axpr/starred_method_class.h new file mode 100644 index 00000000000000..4deaa0ced6d733 --- /dev/null +++ b/paddle/ap/include/axpr/starred_method_class.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> + : public EmptyMethodClass {}; + +template +struct MethodClassImpl>> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/string.h b/paddle/ap/include/axpr/string.h new file mode 100644 index 00000000000000..97ebd324796133 --- /dev/null +++ b/paddle/ap/include/axpr/string.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using value_type = std::string; + + const char* Name() const { return "str"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/string_method_class.h b/paddle/ap/include/axpr/string_method_class.h new file mode 100644 index 00000000000000..66ec2063faf5dc --- /dev/null +++ b/paddle/ap/include/axpr/string_method_class.h @@ -0,0 +1,133 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/string_util.h" + +namespace ap::axpr { + +template +struct StringMethodClass { + using This = StringMethodClass; + using Self = std::string; + using Val = ValueT; + + adt::Result ToString(const Self& self) { return self; } + + adt::Result Hash(const Self& self) { + return static_cast(std::hash()(self)); + } + + adt::Result GetAttr(const Self& self, const Val& attr_name_val) { + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "replace") { + return axpr::Method{self, &This::StaticReplace}; + } + + if (attr_name == "join") { + return axpr::Method{self, + &axpr::WrapAsBuiltinFuncType}; + } + return adt::errors::TypeError{}; + } + + adt::Result Join(const Self& self, const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "join() takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lst, args.at(0).template TryGet>()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of join() should be 'list' not '" + + axpr::GetTypeName(args.at(0)) + "'."}; + std::ostringstream ss; + int i = 0; + for (const auto& elt : *lst) { + ADT_LET_CONST_REF(item, elt.template TryGet()) + << adt::errors::TypeError{std::string() + "sequence item " + + std::to_string(i) + + ": expected str instance, " + + axpr::GetTypeName(elt) + " found"}; + if (i++ > 0) { + ss << self; + } + ss << 
item; + } + return ss.str(); + } + + static adt::Result StaticReplace(const Val& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template TryGet()); + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "'str.replace' takes 2 arguments but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(pattern, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of 'str.replace' should be a str"}; + ADT_LET_CONST_REF(replacement, args.at(1).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 2 of 'str.replace' should be a str"}; + return This{}.Replace(self, pattern, replacement); + } + + std::string Replace(std::string self, + const std::string& pattern, + const std::string& replacement) { + while (true) { + std::size_t pos = self.find(pattern); + if (pos == std::string::npos) { + break; + } + self = self.replace(pos, pattern.size(), replacement); + } + return self; + } + + template + static BuiltinBinaryFunc GetBuiltinBinaryFunc() { + if constexpr (ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::convertible) { + using ArithmeticOp = typename ConvertBuiltinSymbolToArithmetic< + BuiltinBinarySymbol>::arithmetic_op_type; + return &This::template BinaryFunc; + } else { + return adt::Nothing{}; + } + } + + template + static adt::Result BinaryFunc(const Val& lhs_val, const Val& rhs_val) { + const auto& opt_lhs = lhs_val.template TryGet(); + ADT_RETURN_IF_ERR(opt_lhs); + const auto& lhs = opt_lhs.GetOkValue(); + return BuiltinStringBinary(lhs, rhs_val); + } +}; + +template +struct MethodClassImpl : public StringMethodClass {}; + +template +struct MethodClassImpl> + : public EmptyMethodClass {}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/string_util.h b/paddle/ap/include/axpr/string_util.h new file mode 100644 index 00000000000000..e019a9abe3173f --- /dev/null +++ b/paddle/ap/include/axpr/string_util.h @@ -0,0 
+1,112 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/constants.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +struct BuiltinStringBinaryHelper { + static Result Call(const std::string& str, const T& rhs) { + return adt::errors::TypeError{ + std::string() + "unsupported operand types for " + + ArithmeticOp::Name() + ": 'str' and '" + TypeImpl{}.Name() + "'"}; + } +}; + +template +struct BuiltinStringBinaryHelper { + static Result Call(const std::string& lhs, const std::string& rhs) { + return ArithmeticAdd::Call(lhs, rhs); + } +}; + +template +struct BuiltinStringBinaryHelper { + static Result Call(const std::string& lhs, const int64_t& rhs_val) { + return lhs + std::to_string(rhs_val); + } +}; + +template +struct BuiltinStringBinaryHelper { + static Result Call(const std::string& lhs, const bool& rhs_val) { + return lhs + (rhs_val ? 
"True" : "False"); + } +}; + +#define SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(cls_name) \ + template \ + struct BuiltinStringBinaryHelper { \ + static Result Call(const std::string& lhs, const std::string& rhs) { \ + return cls_name::Call(lhs, rhs); \ + } \ + }; +SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(ArithmeticEQ); +SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(ArithmeticNE); +SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(ArithmeticGT); +SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(ArithmeticGE); +SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(ArithmeticLT); +SPECIALIZE_BuiltinStringBinaryHelper_string_cmp(ArithmeticLE); +#undef SPECIALIZE_BuiltinStringBinaryHelper_string + +template +struct BuiltinStringBinaryHelper { + static Result Call(const std::string& lhs, int64_t size) { + size = (size > 0 ? size : 0); + std::ostringstream ss; + for (int i = 0; i < size; ++i) { + ss << lhs; + } + return ss.str(); + } +}; + +template +struct BuiltinStringBinaryHelper { + static Result Call(const std::string& lhs, bool size) { + std::ostringstream ss; + for (int i = 0; i < static_cast(size); ++i) { + ss << lhs; + } + return ss.str(); + } +}; + +template +Result BuiltinStringBinary(const std::string& str, const Val& rhs_val) { + return rhs_val.Match( + [&](const BuiltinClassInstance& impl) -> Result { + return adt::errors::TypeError{ + std::string() + "unsupported operand types for " + + ArithmeticOp::Name() + ": 'str' and '" + + impl.type.class_attrs()->class_name + "'"}; + }, + [&](const ClassInstance& impl) -> Result { + return adt::errors::TypeError{std::string() + + "unsupported operand types for " + + ArithmeticOp::Name() + ": 'str' and '" + + impl->type.class_attrs->class_name + "'"}; + }, + [&](const auto& rhs) -> Result { + using T = std::decay_t; + return BuiltinStringBinaryHelper::Call(str, rhs); + }); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/to_string.h b/paddle/ap/include/axpr/to_string.h new file mode 100644 
index 00000000000000..ccbf0e205449af --- /dev/null +++ b/paddle/ap/include/axpr/to_string.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/method_class.h" + +namespace ap::axpr { + +template +class InterpreterBase; + +template +adt::Result ToString(InterpreterBase* interpreter, + const ValueT& val) { + const auto& func = MethodClass::ToString(val); + using RetT = adt::Result; + return func.Match( + [&](const adt::Nothing&) -> RetT { + return adt::errors::TypeError{GetTypeName(val) + + " class has no __str__ function"}; + }, + [&](adt::Result (*unary_func)(const ValueT&)) -> RetT { + ADT_LET_CONST_REF(str_val, unary_func(val)); + ADT_LET_CONST_REF(str, str_val.template TryGet()); + return str; + }, + [&](adt::Result (*unary_func)(InterpreterBase*, + const ValueT&)) -> RetT { + ADT_LET_CONST_REF(str_val, unary_func(interpreter, val)); + ADT_LET_CONST_REF(str, str_val.template TryGet()); + return str; + }); +} + +template +std::string ToDebugString(InterpreterBase* interpreter, + const ValueT& val) { + const auto& str = ToString(interpreter, val); + if (str.HasError()) { + return "[invalid debug string]"; + } + return str.GetOkValue(); +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/type.h b/paddle/ap/include/axpr/type.h new file mode 100644 index 
00000000000000..65f91a588f4215 --- /dev/null +++ b/paddle/ap/include/axpr/type.h @@ -0,0 +1,75 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/ap/include/axpr/adt.h" + +namespace ap::axpr { + +template +struct Type; + +template +struct TypeImpl {}; + +template +struct TypeImpl> : public std::monostate { + using value_type = Type; + + const char* Name() const { return "type"; } +}; + +template +using TypeBase = std::variant, TypeImpl...>; + +template +struct Type : public TypeBase, Ts...> { + using TypeBase, Ts...>::TypeBase; + + ADT_DEFINE_VARIANT_METHODS(TypeBase, Ts...>); + + std::string Name() const { + return Match([](const auto& impl) -> std::string { return impl.Name(); }); + } +}; + +namespace detail { + +template +struct IsTypeHelper { + static constexpr const bool value = false; +}; + +template +struct IsTypeHelper> { + static constexpr const bool value = true; +}; + +} // namespace detail + +template +constexpr bool IsType() { + return detail::IsTypeHelper::value; +} + +template +struct TypeTrait { + using VariantT = std::decay_t().variant())>; + using TypeT = std::variant_alternative_t<0, VariantT>; +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/type_method_class.h b/paddle/ap/include/axpr/type_method_class.h new file mode 100644 index 00000000000000..f168d38f3367d7 --- /dev/null +++ 
b/paddle/ap/include/axpr/type_method_class.h @@ -0,0 +1,80 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/class_instance.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +template +struct MethodClassImpl> {}; + +template +struct MethodClassImpl>> { + using Self = TypeImpl>; + using This = MethodClassImpl; + + adt::Result Call(const Self& value) { + return &This::StaticGetOrConstruct; + } + + static adt::Result StaticGetOrConstruct( + const ValueT& self_val, const std::vector& args) { + if (args.size() == 1) { + return GetType(args.at(0)); + } + if (args.size() == 3) { + return This{}.MakeClass(args.at(0), args.at(1), args.at(2)); + } + return adt::errors::TypeError{std::string() + + "type() takes 1 or 3 arguments, but " + + std::to_string(args.size()) + " were given."}; + } + + adt::Result MakeClass(const ValueT& class_name_val, + const ValueT& superclasses_val, + const ValueT& attributes_object) { + ADT_LET_CONST_REF(class_name, class_name_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "the argument 1 of type() should be str not " + + GetTypeName(class_name_val)}; + adt::List>> superclasses; + { + ADT_LET_CONST_REF(superclass_vals, + superclasses_val.template TryGet>()) + << 
adt::errors::TypeError{ + std::string() + + "the argument 2 of type() should be list not " + + GetTypeName(superclasses_val)}; + superclasses->reserve(superclass_vals->size()); + for (const auto& superclass_val : *superclass_vals) { + ADT_LET_CONST_REF( + type_impl, + TryGetTypeImpl>>(superclass_val)); + superclasses->emplace_back(type_impl.class_attrs.shared_ptr()); + } + } + ADT_LET_CONST_REF( + attrs, + attributes_object.template TryGet>()); + ClassAttrs class_attrs{class_name, superclasses, attrs}; + return TypeImpl>{class_attrs}; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/type_util.h b/paddle/ap/include/axpr/type_util.h new file mode 100644 index 00000000000000..8eba597f67ff05 --- /dev/null +++ b/paddle/ap/include/axpr/type_util.h @@ -0,0 +1,75 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/data_value.h" +#include "paddle/ap/include/axpr/mutable_list.h" +#include "paddle/ap/include/axpr/mutable_ordered_dict.h" +#include "paddle/ap/include/axpr/ordered_dict.h" +#include "paddle/ap/include/axpr/packed_args.h" +#include "paddle/ap/include/axpr/pointer_type.h" +#include "paddle/ap/include/axpr/pointer_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::axpr { + +namespace detail { + +template +struct GetTypeName2TypeHelper; + +template +struct GetTypeName2TypeHelper { + static void Call(AttrMap*) {} +}; + +template +struct GetTypeName2TypeHelper { + static void Call(AttrMap* ret) { + TypeImpl type_impl{}; + ValueT type{type_impl}; + (*ret)->Set(type_impl.Name(), type); + GetTypeName2TypeHelper::Call(ret); + } +}; + +} // namespace detail + +template +AttrMap GetObjectTypeName2Type() { + AttrMap object; + detail::GetTypeName2TypeHelper::TypeT, + Nothing, + bool, + int64_t, + double, + std::string, + DataType, + DataValue, + PointerType, + PointerValue, + MutableList, + OrderedDict, + MutableOrderedDict, + PackedArgs, + AttrMap, + ValueImplTypes...>::Call(&object); + return object; +} + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/unary_func.h b/paddle/ap/include/axpr/unary_func.h new file mode 100644 index 00000000000000..36ba764a7470f2 --- /dev/null +++ b/paddle/ap/include/axpr/unary_func.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// Required for std::is_same_v / std::is_integral_v / std::is_floating_point
// used below; the header must be self-contained.
#include <cstdint>
#include <type_traits>

namespace ap::axpr {

// X-macro enumerating the supported arithmetic unary operators.
#define PEXPR_FOR_EACH_UNARY_OP(_) \
  _(Not, !)                        \
  _(Neg, -)

// Generates ArithmeticNot / ArithmeticNeg: a Name() returning the operator
// token and a Call() applying it.
#define DEFINE_ARITHMETIC_UNARY_OP(name, op)             \
  struct Arithmetic##name {                              \
    static constexpr const char* Name() { return #op; }  \
                                                         \
    template <typename LhsT>                             \
    static auto Call(const LhsT& val) {                  \
      return op val;                                     \
    }                                                    \
  };
PEXPR_FOR_EACH_UNARY_OP(DEFINE_ARITHMETIC_UNARY_OP);
#undef DEFINE_ARITHMETIC_UNARY_OP

// Applies ArithmeticOp and normalizes the result onto the interpreter's
// three numeric kinds: bool stays bool, any integral widens to int64_t, and
// anything else must be floating point and becomes double.
template <typename ArithmeticOp>
struct BoolIntDoubleUnary {
  static constexpr const char* Name() { return ArithmeticOp::Name(); }

  template <typename T>
  static auto Call(T operand) {
    auto ret = ArithmeticOp::Call(operand);
    using RetT = decltype(ret);
    if constexpr (std::is_same_v<RetT, bool>) {
      return ret;
    } else if constexpr (std::is_integral_v<RetT>) {
      return static_cast<int64_t>(ret);
    } else {
      static_assert(std::is_floating_point<RetT>::value, "");
      return static_cast<double>(ret);
    }
  }
};

}  // namespace ap::axpr
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/bool.h" +#include "paddle/ap/include/axpr/builtin_func_type.h" +#include "paddle/ap/include/axpr/builtin_high_order_func_type.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map.h" +#include "paddle/ap/include/axpr/builtin_symbol.h" +#include "paddle/ap/include/axpr/class_instance.h" +#include "paddle/ap/include/axpr/closure.h" +#include "paddle/ap/include/axpr/continuation.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/data_value.h" +#include "paddle/ap/include/axpr/environment.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/axpr/float.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/int.h" +#include "paddle/ap/include/axpr/list.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/mutable_list.h" +#include "paddle/ap/include/axpr/mutable_ordered_dict.h" +#include "paddle/ap/include/axpr/nothing.h" +#include "paddle/ap/include/axpr/ordered_dict.h" +#include "paddle/ap/include/axpr/packed_args.h" +#include "paddle/ap/include/axpr/pointer_type.h" +#include "paddle/ap/include/axpr/pointer_value.h" +#include "paddle/ap/include/axpr/serializable_list.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/starred.h" +#include "paddle/ap/include/axpr/string.h" +#include "paddle/ap/include/axpr/type.h" +#include 
"paddle/ap/include/axpr/type_util.h" + +namespace ap::axpr { + +using adt::Nothing; + +template +using ValueBase = std::variant, + adt::List, + MutableList, + AttrMap, + AttrMap, + OrderedDict, + MutableOrderedDict, + BuiltinClassInstance, + ClassInstance, + PackedArgs, + Starred, + Function, + Closure, + Continuation, + Method, + builtin_symbol::Symbol, + BuiltinFuncType, + BuiltinHighOrderFuncType, + Ts...>, + Nothing, + bool, + int64_t, + double, + std::string, + DataType, + DataValue, + PointerType, + PointerValue, + adt::List, + adt::List, + MutableList, + AttrMap, + AttrMap, + OrderedDict, + MutableOrderedDict, + BuiltinClassInstance, + ClassInstance, + PackedArgs, + Starred, + Function, + Closure, + Continuation, + Method, + builtin_symbol::Symbol, + BuiltinFuncType, + BuiltinHighOrderFuncType, + Ts...>; +template +ValueT GetType(const ValueT& value) { + return value.Match( + [](const BuiltinClassInstance& impl) -> ValueT { + return impl.type; + }, + [](const ClassInstance& impl) -> ValueT { return impl->type; }, + [](const auto& impl) -> ValueT { + using T = std::decay_t; + return TypeImpl{}; + }); +} + +template +adt::Result::TypeT> CastToType(const ValueT& value) { + ADT_LET_CONST_REF(type, + value.template TryGet::TypeT>()); + return type; +} + +template +adt::Result TryGetTypeImpl(const ValueT& value) { + ADT_LET_CONST_REF(type, CastToType(value)); + ADT_LET_CONST_REF(type_impl, type.template TryGet()); + return type_impl; +} + +template +adt::Result TryGetBuiltinClassInstance(const ValueT& val) { + ADT_LET_CONST_REF(instance, + val.template TryGet>()); + ADT_LET_CONST_REF(ret, instance.template TryGet()); + return ret; +} + +template +adt::Result Get(const ValueT& val) { + using TypeT = typename TypeTrait::TypeT; + if constexpr (ValueT::template IsMyAlternative()) { + return val.template TryGet(); + } else if constexpr (TypeT::template IsMyAlternative()) { + return TryGetTypeImpl(val); + } else { + return TryGetBuiltinClassInstance(val); + } +} + 
+template +adt::Result CastableTo(const ValueT& val) { + using TypeT = typename TypeTrait::TypeT; + if constexpr (ValueT::template IsMyAlternative()) { + return val.template Has(); + } else if constexpr (TypeT::template IsMyAlternative()) { + ADT_LET_CONST_REF(type, CastToType(val)); + return type.template Has(); + } else { + ADT_LET_CONST_REF(instance, + val.template TryGet>()); + return instance.template Has(); + } +} + +struct Value : public ValueBase { + using ValueBase::ValueBase; + ADT_DEFINE_VARIANT_METHODS(ValueBase); + + static axpr::AttrMap GetExportedTypes() { + return axpr::GetObjectTypeName2Type(); + } + + template + adt::Result CastTo() const { + return axpr::Get(*this); + } + + template + bool CastableTo() const { + const auto& ret = axpr::CastableTo(*this); + return ret.HasOkValue() ? ret.GetOkValue() : false; + } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/axpr/value_method_class.h b/paddle/ap/include/axpr/value_method_class.h new file mode 100644 index 00000000000000..420ccd0e37d4b8 --- /dev/null +++ b/paddle/ap/include/axpr/value_method_class.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/attr_map_method_class.h" +#include "paddle/ap/include/axpr/bool_method_class.h" +#include "paddle/ap/include/axpr/builtin_class_instance_method_class.h" +#include "paddle/ap/include/axpr/builtin_func_type_method_class.h" +#include "paddle/ap/include/axpr/builtin_high_order_func_type_method_class.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map_method_class.h" +#include "paddle/ap/include/axpr/builtin_symbol_method_class.h" +#include "paddle/ap/include/axpr/class_instance_method_class.h" +#include "paddle/ap/include/axpr/closure_method_class.h" +#include "paddle/ap/include/axpr/continuation_method_class.h" +#include "paddle/ap/include/axpr/data_type_method_class.h" +#include "paddle/ap/include/axpr/data_value_method_class.h" +#include "paddle/ap/include/axpr/float_method_class.h" +#include "paddle/ap/include/axpr/function_method_class.h" +#include "paddle/ap/include/axpr/int_method_class.h" +#include "paddle/ap/include/axpr/list_method_class.h" +#include "paddle/ap/include/axpr/method_method_class.h" +#include "paddle/ap/include/axpr/mutable_list_method_class.h" +#include "paddle/ap/include/axpr/mutable_ordered_dict_method_class.h" +#include "paddle/ap/include/axpr/nothing_method_class.h" +#include "paddle/ap/include/axpr/ordered_dict_method_class.h" +#include "paddle/ap/include/axpr/packed_args_method_class.h" +#include "paddle/ap/include/axpr/pointer_type_method_class.h" +#include "paddle/ap/include/axpr/pointer_value_method_class.h" +#include "paddle/ap/include/axpr/serializable_list_method_class.h" +#include "paddle/ap/include/axpr/starred_method_class.h" +#include "paddle/ap/include/axpr/string_method_class.h" +#include "paddle/ap/include/axpr/type_method_class.h" diff --git a/paddle/ap/include/code_gen/arg_source_ctx.h b/paddle/ap/include/code_gen/arg_source_ctx.h new file mode 100644 index 00000000000000..bf9eb0364a403f --- /dev/null +++ b/paddle/ap/include/code_gen/arg_source_ctx.h @@ -0,0 
+1,206 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::code_gen { + +struct NativeIrValueSource { + int native_ir_value_index; + + bool operator==(const NativeIrValueSource& other) const { + return this->native_ir_value_index == other.native_ir_value_index; + } +}; + +struct PackedIrValueSource { + int packed_ir_value_index; + int tensor_member_index; + + bool operator==(const PackedIrValueSource& other) const { + return this->packed_ir_value_index == other.packed_ir_value_index && + this->tensor_member_index == other.tensor_member_index; + } +}; + +using TensorSourceImpl = std::variant; + +struct TensorSource : public TensorSourceImpl { + using TensorSourceImpl::TensorSourceImpl; + ADT_DEFINE_VARIANT_METHODS(TensorSourceImpl); +}; + +struct InTensorSource { + TensorSource tensor_source; + + bool operator==(const InTensorSource& other) const { + return this->tensor_source == other.tensor_source; + } +}; + +struct OutTensorSource { + TensorSource tensor_source; + + bool operator==(const OutTensorSource& other) const { + return this->tensor_source == other.tensor_source; + } +}; + +using InOutTensorSourceImpl = std::variant; + +struct InOutTensorSource : public InOutTensorSourceImpl { + using InOutTensorSourceImpl::InOutTensorSourceImpl; + 
ADT_DEFINE_VARIANT_METHODS(InOutTensorSourceImpl); +}; + +struct ShapeDimSource { + InOutTensorSource tensor_source; + int dim_axis; + + bool operator==(const ShapeDimSource& other) const { + return this->tensor_source == other.tensor_source && + this->dim_axis == other.dim_axis; + } +}; + +struct DataDimSource { + InOutTensorSource tensor_source; + int dim_axis; + + bool operator==(const DataDimSource& other) const { + return this->tensor_source == other.tensor_source && + this->dim_axis == other.dim_axis; + } +}; + +using DimSourceImpl = std::variant; +struct DimSource : public DimSourceImpl { + using DimSourceImpl::DimSourceImpl; + ADT_DEFINE_VARIANT_METHODS(DimSourceImpl); +}; + +template +struct ArgSourceCtxImpl { + std::vector> input_and_tensor_source_pairs; + std::vector> + output_and_tensor_source_pairs; + std::vector> + dim_expr_and_dim_source_pairs; + std::unordered_map dim_expr2dim_source; + + std::optional GetInputTensorSource( + const BirNode& node) const { + for (const auto& [k, v] : this->input_and_tensor_source_pairs) { + if (k == node) { + return &v; + } + } + return std::nullopt; + } + + std::optional GetOutputTensorSource( + const BirNode& node) const { + for (const auto& [k, v] : this->output_and_tensor_source_pairs) { + if (k == node) { + return &v; + } + } + return std::nullopt; + } + + std::optional GetDimExprSource( + const symbol::DimExpr& dim_expr) const { + const auto& iter = this->dim_expr2dim_source.find(dim_expr); + if (iter == this->dim_expr2dim_source.end()) { + return std::nullopt; + } + return &iter->second; + } + + bool HasDirectOrIndirectDimExprSource(const symbol::DimExpr& dim_expr) const { + if (GetDimExprSource(dim_expr).has_value()) { + return true; + } + return dim_expr.Match([&](const auto& impl) { + return HasDirectOrIndirectDimExprSourceImpl(impl); + }); + } + + private: + bool HasDirectOrIndirectDimExprSourceImpl(int64_t) const { return true; } + + bool HasDirectOrIndirectDimExprSourceImpl(const std::string& dim_expr) 
const { + return GetDimExprSource(dim_expr).has_value(); + } + + using Negative = symbol::Negative; + bool HasDirectOrIndirectDimExprSourceImpl(const Negative& dim_expr) const { + return HasDirectOrIndirectUnaryDimExprSource(dim_expr); + } + + using Reciprocal = symbol::Reciprocal; + bool HasDirectOrIndirectDimExprSourceImpl(const Reciprocal& dim_expr) const { + return HasDirectOrIndirectUnaryDimExprSource(dim_expr); + } + + template + bool HasDirectOrIndirectUnaryDimExprSource(const T& dim_expr) const { + const auto& [operand] = *dim_expr; + return HasDirectOrIndirectDimExprSource(operand); + } + + using Add = symbol::Add; + bool HasDirectOrIndirectDimExprSourceImpl(const Add& dim_expr) const { + return HasDirectOrIndirectVariadicDimExprSource(dim_expr); + } + + using Mul = symbol::Mul; + bool HasDirectOrIndirectDimExprSourceImpl(const Mul& dim_expr) const { + return HasDirectOrIndirectVariadicDimExprSource(dim_expr); + } + + using Max = symbol::Max; + bool HasDirectOrIndirectDimExprSourceImpl(const Max& dim_expr) const { + return HasDirectOrIndirectVariadicDimExprSource(dim_expr); + } + + using Min = symbol::Min; + bool HasDirectOrIndirectDimExprSourceImpl(const Min& dim_expr) const { + return HasDirectOrIndirectVariadicDimExprSource(dim_expr); + } + + using Broadcast = symbol::Broadcast; + bool HasDirectOrIndirectDimExprSourceImpl(const Broadcast& dim_expr) const { + return HasDirectOrIndirectVariadicDimExprSource(dim_expr); + } + + template + bool HasDirectOrIndirectVariadicDimExprSource(const T& dim_expr) const { + const auto& [operands] = dim_expr; + for (const auto& operand : *operands) { + if (!HasDirectOrIndirectDimExprSource(operand)) { + return false; + } + } + return true; + } +}; + +template +ADT_DEFINE_RC(ArgSourceCtx, ArgSourceCtxImpl); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/arg_source_helper.h b/paddle/ap/include/code_gen/arg_source_helper.h new file mode 100644 index 00000000000000..362f44e8750b72 --- /dev/null +++ 
b/paddle/ap/include/code_gen/arg_source_helper.h @@ -0,0 +1,361 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/anf_expr_builder.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/code_gen/arg_source_ctx.h" +#include "paddle/ap/include/code_gen/kernel_arg_id.h" + +namespace ap::code_gen { + +template +struct ArgSourceHelper { + const ArgSourceCtx& arg_source_ctx; + + adt::Result> + MakeRuntimeKerneArgsGetter( + const std::list>& kernel_arg_ids) const { + auto GetBody = + [&](axpr::LetVar* dispatch_ctx) -> adt::Result { + std::vector items; + items.reserve(kernel_arg_ids.size()); + for (const auto& kernel_arg_id : kernel_arg_ids) { + ADT_LET_CONST_REF(elt, MakeRuntimeGetter(dispatch_ctx, kernel_arg_id)); + items.emplace_back(*elt); + } + auto* ctx = dispatch_ctx->ctx(); + auto* ret_ptr = &ctx->Var(ctx->NewTmpVarName()); + *ret_ptr = ctx->Var(axpr::kBuiltinList()).Call(items); + return ret_ptr; + }; + ADT_LET_CONST_REF(lambda, CreateLambda("ctx", GetBody)); + return axpr::Function{lambda, std::nullopt}; + } + + adt::Result> + MakeRuntimeKerneArgGetter(const KernelArgId& kernel_arg_id) const { + auto GetBody = + 
[&](axpr::LetVar* dispatch_ctx) -> adt::Result { + return MakeRuntimeGetter(dispatch_ctx, kernel_arg_id); + }; + ADT_LET_CONST_REF(lambda, CreateLambda("ctx", GetBody)); + return axpr::Function{lambda, std::nullopt}; + } + + private: + adt::Result MakeRuntimeGetter( + axpr::LetVar* dispatch_ctx, + const KernelArgId& kernel_arg_id) const { + return kernel_arg_id.Match([&](const auto& impl) { + return MakeRuntimeGetterImpl(dispatch_ctx, impl); + }); + } + + adt::Result MakeRuntimeGetterImpl( + axpr::LetVar* dispatch_ctx, + const InTensorDataPtrKernelArgId& kernel_arg_id) const { + const auto& opt_in_tensor_source = + arg_source_ctx->GetInputTensorSource(kernel_arg_id->ir_value); + ADT_CHECK(opt_in_tensor_source.has_value()); + const auto* in_tensor_source = opt_in_tensor_source.value(); + ADT_LET_CONST_REF( + tensor_var_ptr, + MakeGetterAnfExprByInTensorSource(dispatch_ctx, *in_tensor_source)); + return &tensor_var_ptr->Attr("data_ptr"); + } + + adt::Result MakeRuntimeGetterImpl( + axpr::LetVar* dispatch_ctx, + const OutTensorDataPtrKernelArgId& kernel_arg_id) const { + const auto& opt_out_tensor_source = + arg_source_ctx->GetOutputTensorSource(kernel_arg_id->ir_value); + ADT_CHECK(opt_out_tensor_source.has_value()); + const auto* out_tensor_source = opt_out_tensor_source.value(); + ADT_LET_CONST_REF( + tensor_var_ptr, + MakeGetterAnfExprByOutTensorSource(dispatch_ctx, *out_tensor_source)); + return &tensor_var_ptr->Attr("data_ptr"); + } + + template + adt::Result> CreateLambda( + const std::string& dispatch_ctx_name, const GetVarPtrT& GetVarPtr) const { + axpr::LambdaExprBuilder lmbd; + auto GetBody = + [&](axpr::LetContext& ctx) -> adt::Result { // NOLINT + ADT_LET_CONST_REF(var_ptr, GetVarPtr(&ctx.Var(dispatch_ctx_name))); + return static_cast(*var_ptr); + }; + ADT_LET_CONST_REF(anf_expr, lmbd.TryLambda({dispatch_ctx_name}, GetBody)); + const auto& core_expr = ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + 
ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return lambda; + } + + adt::Result MakeGetterAnfExprByInTensorSource( + axpr::LetVar* dispatch_ctx, + const InTensorSource& in_tensor_source) const { + auto& inputs_var = dispatch_ctx->Attr("inputs"); + return MakeGetterAnfExprByTensorSource(&inputs_var, + in_tensor_source.tensor_source); + } + + adt::Result MakeGetterAnfExprByOutTensorSource( + axpr::LetVar* dispatch_ctx, + const OutTensorSource& out_tensor_source) const { + auto& outputs_var = dispatch_ctx->Attr("outputs"); + return MakeGetterAnfExprByTensorSource(&outputs_var, + out_tensor_source.tensor_source); + } + + adt::Result MakeGetterAnfExprByTensorSource( + axpr::LetVar* in_out_tensors, const TensorSource& tensor_source) const { + using RetT = adt::Result; + return tensor_source.Match( + [&](const NativeIrValueSource& native) -> RetT { + return &in_out_tensors->At(native.native_ir_value_index); + }, + [&](const PackedIrValueSource& packed) -> RetT { + return &in_out_tensors->At(packed.packed_ir_value_index) + .At(packed.tensor_member_index); + }); + } + + adt::Result MakeRuntimeGetterImpl( + axpr::LetVar* dispatch_ctx, + const DimExprKernelArgId& kernel_arg_id) const { + ADT_CHECK(arg_source_ctx->HasDirectOrIndirectDimExprSource( + kernel_arg_id->dim_expr)); + return MakeGetterAnfExprByDimExpr(dispatch_ctx, kernel_arg_id->dim_expr); + } + + adt::Result MakeGetterAnfExprByDimSource( + axpr::LetVar* dispatch_ctx, const DimSource& dim_source) const { + using RetT = adt::Result; + return dim_source.Match( + [&](const ShapeDimSource& shape_dim_source) -> RetT { + ADT_LET_CONST_REF(tensor_var_ptr, + MakeGetterAnfExprByInOutTensorSource( + dispatch_ctx, shape_dim_source.tensor_source)); + auto* ctx = dispatch_ctx->ctx(); + auto* dim_expr = + &tensor_var_ptr->Attr("shape").At(shape_dim_source.dim_axis); + auto* data_value = &ctx->Var("DataValue").Call(*dim_expr); + auto* ret = &ctx->Var(ctx->NewTmpVarName()); + *ret = data_value->Attr("cast").Call( + 
ctx->Var("DataType").Attr("const_int64")); + return ret; + }, + [&](const DataDimSource& data_dim_source) -> RetT { + return adt::errors::TypeError{"DataDimSource is not supported yet."}; + }); + } + + adt::Result MakeGetterAnfExprByInOutTensorSource( + axpr::LetVar* dispatch_ctx, + const InOutTensorSource& in_out_tensor_source) const { + using RetT = adt::Result; + return in_out_tensor_source.Match( + [&](const InTensorSource& in_tensor_source) -> RetT { + return MakeGetterAnfExprByInTensorSource(dispatch_ctx, + in_tensor_source); + }, + [&](const OutTensorSource& out_tensor_source) -> RetT { + return MakeGetterAnfExprByOutTensorSource(dispatch_ctx, + out_tensor_source); + }); + } + + adt::Result MakeGetterAnfExprByDimExpr( + axpr::LetVar* dispatch_ctx, const symbol::DimExpr& dim_expr) const { + const auto& opt_dim_source = arg_source_ctx->GetDimExprSource(dim_expr); + if (opt_dim_source.has_value()) { + return MakeGetterAnfExprByDimSource(dispatch_ctx, + *opt_dim_source.value()); + } + return dim_expr.Match([&](const auto& impl) { + return MakeGetterAnfExprByDimExprImpl(dispatch_ctx, impl); + }); + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, int64_t c) const { + auto* ctx = dispatch_ctx->ctx(); + auto* ret_var = &ctx->Var(ctx->NewTmpVarName()); + *ret_var = ctx->Int64(c); + return ret_var; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, const std::string& symbol) const { + return adt::errors::NotImplementedError{ + "Dead code. Symbols have been handled in MakeGetterAnfExprByDimExpr()"}; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Negative& dim_expr) const { + return adt::errors::NotImplementedError{ + "Dead code. 
Negative dim_exprs have been handled in " + "MakeGetterAnfExprByDimExprImpl(dispatch_ctx, const " + "symbol::Add&)"}; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Reciprocal& dim_expr) const { + return adt::errors::NotImplementedError{ + "Dead code. Reciprocal dim_exprs have been handled in " + "MakeGetterAnfExprByDimExprImpl(dispatch_ctx, const " + "symbol::Mul&)"}; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Add& dim_expr) const { + const auto& [operands] = dim_expr; + ADT_CHECK(operands->size() > 0); + auto* ctx = dispatch_ctx->ctx(); + ADT_LET_CONST_REF( + init_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operands->at(0))); + axpr::LetVar* ret_var_ptr = init_var_ptr; + for (int i = 1; i < operands->size(); ++i) { + const auto& operand = operands->at(i); + auto* tmp_var_ptr = &ctx->Var(ctx->NewTmpVarName()); + if (operand.template Has>()) { + const auto& [operand_operand] = + *operand.template Get>(); + ADT_LET_CONST_REF( + operand_operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand_operand)); + *tmp_var_ptr = ctx->Call( + axpr::kBuiltinSub(), *ret_var_ptr, *operand_operand_var_ptr); + } else { + ADT_LET_CONST_REF(operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand)); + *tmp_var_ptr = + ctx->Call(axpr::kBuiltinAdd(), *ret_var_ptr, *operand_var_ptr); + } + ret_var_ptr = tmp_var_ptr; + } + return ret_var_ptr; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Mul& dim_expr) const { + const auto& [operands] = dim_expr; + ADT_CHECK(operands->size() > 0); + auto* ctx = dispatch_ctx->ctx(); + ADT_LET_CONST_REF( + init_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operands->at(0))); + axpr::LetVar* ret_var_ptr = init_var_ptr; + for (int i = 1; i < operands->size(); ++i) { + const auto& operand = operands->at(i); + auto* tmp_var_ptr = &ctx->Var(ctx->NewTmpVarName()); + if 
(operand.template Has>()) { + const auto& [operand_operand] = + *operand.template Get>(); + ADT_LET_CONST_REF( + operand_operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand_operand)); + *tmp_var_ptr = ctx->Call( + axpr::kBuiltinDiv(), *ret_var_ptr, *operand_operand_var_ptr); + } else { + ADT_LET_CONST_REF(operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand)); + *tmp_var_ptr = + ctx->Call(axpr::kBuiltinMul(), *ret_var_ptr, *operand_var_ptr); + } + ret_var_ptr = tmp_var_ptr; + } + return ret_var_ptr; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Max& dim_expr) const { + const auto& [operands] = dim_expr; + ADT_CHECK(operands->size() > 0); + auto* ctx = dispatch_ctx->ctx(); + ADT_LET_CONST_REF( + init_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operands->at(0))); + axpr::LetVar* ret_var_ptr = init_var_ptr; + for (int i = 1; i < operands->size(); ++i) { + const auto& operand = operands->at(i); + auto* tmp_var_ptr = &ctx->Var(ctx->NewTmpVarName()); + ADT_LET_CONST_REF(operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand)); + *tmp_var_ptr = ctx->Call("max", *ret_var_ptr, *operand_var_ptr); + ret_var_ptr = tmp_var_ptr; + } + return ret_var_ptr; + } + + adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Min& dim_expr) const { + const auto& [operands] = dim_expr; + ADT_CHECK(operands->size() > 0); + auto* ctx = dispatch_ctx->ctx(); + ADT_LET_CONST_REF( + init_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operands->at(0))); + axpr::LetVar* ret_var_ptr = init_var_ptr; + for (int i = 1; i < operands->size(); ++i) { + const auto& operand = operands->at(i); + auto* tmp_var_ptr = &ctx->Var(ctx->NewTmpVarName()); + ADT_LET_CONST_REF(operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand)); + *tmp_var_ptr = ctx->Call("min", *ret_var_ptr, *operand_var_ptr); + ret_var_ptr = tmp_var_ptr; + } + return ret_var_ptr; + } + + 
adt::Result MakeGetterAnfExprByDimExprImpl( + axpr::LetVar* dispatch_ctx, + const symbol::Broadcast& dim_expr) const { + const auto& [operands] = dim_expr; + ADT_CHECK(operands->size() > 0); + auto* ctx = dispatch_ctx->ctx(); + ADT_LET_CONST_REF( + init_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operands->at(0))); + axpr::LetVar* ret_var_ptr = init_var_ptr; + for (int i = 1; i < operands->size(); ++i) { + const auto& operand = operands->at(i); + auto* tmp_var_ptr = &ctx->Var(ctx->NewTmpVarName()); + ADT_LET_CONST_REF(operand_var_ptr, + MakeGetterAnfExprByDimExpr(dispatch_ctx, operand)); + *tmp_var_ptr = ctx->Call("max", *ret_var_ptr, *operand_var_ptr); + ret_var_ptr = tmp_var_ptr; + } + return ret_var_ptr; + } +}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/arg_source_maker.h b/paddle/ap/include/code_gen/arg_source_maker.h new file mode 100644 index 00000000000000..cb10c91e19d231 --- /dev/null +++ b/paddle/ap/include/code_gen/arg_source_maker.h @@ -0,0 +1,168 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/code_gen/arg_source_ctx.h" +#include "paddle/ap/include/code_gen/matched_result_pattern_helper.h" + +namespace ap::code_gen { + +template +struct ArgSourceMaker { + const code_gen::MatchedResultPatternHelper& matched_res_ptn_helper; + + using DrrValue = drr::Value; + using DrrNode = drr::Node; + using DrrPackedIrOp = drr::PackedIrOp; + + adt::Result> MakeArgSourceCtx( + const DrrPackedIrOp& res_ptn_ir_op) const { + ADT_LET_CONST_REF(input_and_tensor_source_pairs, + MakeInputAndTensorSourcePairs(res_ptn_ir_op)); + ADT_LET_CONST_REF(output_and_tensor_source_pairs, + MakeOutputAndTensorSourcePairs(res_ptn_ir_op)); + std::vector> + dim_expr_and_dim_source_pairs; + ADT_RETURN_IF_ERR(CollectDimExprAndDimSourcePairs( + &dim_expr_and_dim_source_pairs, input_and_tensor_source_pairs)); + ADT_RETURN_IF_ERR(CollectDimExprAndDimSourcePairs( + &dim_expr_and_dim_source_pairs, output_and_tensor_source_pairs)); + std::unordered_map dim_expr2dim_source{ + dim_expr_and_dim_source_pairs.begin(), + dim_expr_and_dim_source_pairs.end()}; + return ArgSourceCtx{input_and_tensor_source_pairs, + output_and_tensor_source_pairs, + dim_expr_and_dim_source_pairs, + dim_expr2dim_source}; + } + + private: + adt::Result>> + MakeInputAndTensorSourcePairs(const DrrPackedIrOp& res_ptn_ir_op) const { + using Ok = adt::Result; + std::vector inputs; + { + auto CollectInput = [&](const BirNode& bir_node) -> Ok { + inputs.emplace_back(bir_node); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + matched_res_ptn_helper.VisitMatchedBirInputOfRestPtnPackedIrOp( + res_ptn_ir_op, CollectInput)); + } + using Pair = std::pair; + std::vector ret; + ret.reserve(inputs.size()); + { + std::size_t input_idx = 0; + auto DoEachIndex = [&](std::size_t index) -> Ok { + ADT_CHECK(index < inputs.size()); + const auto& bir_node = inputs.at(index); + NativeIrValueSource 
native_ir_value_source{input_idx}; + TensorSource tensor_source{native_ir_value_source}; + InTensorSource in_tensor_source{tensor_source}; + Pair pair{bir_node, in_tensor_source}; + ret.emplace_back(pair); + ++input_idx; + return adt::Ok{}; + }; + auto DoEachSlice = [&](std::size_t start, std::size_t end) -> Ok { + for (std::size_t i = start; i < end; ++i) { + ADT_CHECK(i < inputs.size()); + const auto& bir_node = inputs.at(i); + PackedIrValueSource packed_ir_value_source{input_idx, i}; + TensorSource tensor_source{packed_ir_value_source}; + InTensorSource in_tensor_source{tensor_source}; + Pair pair{bir_node, in_tensor_source}; + ret.emplace_back(pair); + } + ++input_idx; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(matched_res_ptn_helper.VisitApKernelInputIndexOrSlice( + res_ptn_ir_op, DoEachIndex, DoEachSlice)); + } + return ret; + } + + adt::Result>> + MakeOutputAndTensorSourcePairs(const DrrPackedIrOp& res_ptn_ir_op) const { + using Ok = adt::Result; + std::vector outputs; + { + auto CollectOutput = [&](const BirNode& bir_node) -> Ok { + outputs.emplace_back(bir_node); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + matched_res_ptn_helper.VisitMatchedBirOutputOfRestPtnPackedIrOp( + res_ptn_ir_op, CollectOutput)); + } + using Pair = std::pair; + std::vector ret; + ret.reserve(outputs.size()); + { + std::size_t output_idx = 0; + auto DoEachIndex = [&](std::size_t index) -> Ok { + ADT_CHECK(index < outputs.size()); + const auto& bir_node = outputs.at(index); + OutTensorSource out_tensor_source{ + TensorSource{NativeIrValueSource{output_idx}}}; + ret.emplace_back(Pair{bir_node, out_tensor_source}); + ++output_idx; + return adt::Ok{}; + }; + auto DoEachSlice = [&](std::size_t start, std::size_t end) -> Ok { + for (std::size_t i = start; i < end; ++i) { + ADT_CHECK(i < outputs.size()); + const auto& bir_node = outputs.at(i); + OutTensorSource out_tensor_source{ + TensorSource{PackedIrValueSource{output_idx, i}}}; + ret.emplace_back(Pair{bir_node, 
out_tensor_source}); + } + ++output_idx; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(matched_res_ptn_helper.VisitApKernelOutputIndexOrSlice( + res_ptn_ir_op, DoEachIndex, DoEachSlice)); + } + return ret; + } + + template + adt::Result CollectDimExprAndDimSourcePairs( + std::vector>* + dim_expr_and_dim_source_pairs, + const std::vector>& tensor_and_sources) + const { + for (const auto& [bir_node, tensor_source] : tensor_and_sources) { + ADT_LET_CONST_REF( + bir_value, matched_res_ptn_helper.CastToBirNativeIrValue(bir_node)); + ADT_LET_CONST_REF(dim_exprs_ptr, bir_value.GetShapeDimExprsPtr()); + for (int i = 0; i < dim_exprs_ptr->size(); ++i) { + const auto& dim_expr = dim_exprs_ptr->at(i); + using Pair = std::pair; + DimSource dim_source{ShapeDimSource{tensor_source, i}}; + Pair pair{dim_expr, dim_source}; + dim_expr_and_dim_source_pairs->emplace_back(pair); + } + } + return adt::Ok{}; + } +}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/builtin_frame_util.h b/paddle/ap/include/code_gen/builtin_frame_util.h new file mode 100644 index 00000000000000..12c4a98e49854b --- /dev/null +++ b/paddle/ap/include/code_gen/builtin_frame_util.h @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/axpr/dim_expr_method_class.h" +#include "paddle/ap/include/code_gen/code_gen_result_method_class.h" +#include "paddle/ap/include/code_module/code_module_method_class.h" +#include "paddle/ap/include/code_module/func_declare_method_class.h" +#include "paddle/ap/include/code_module/package_method_class.h" +#include "paddle/ap/include/code_module/project_method_class.h" +#include "paddle/ap/include/index_expr/index_expr_method_class.h" +#include "paddle/ap/include/index_expr/index_tuple_expr_method_class.h" +#include "paddle/ap/include/index_expr/slice_method_class.h" + +namespace ap::code_gen { + +template +void VisitEachBuiltinFrameClass(const DoEachT& DoEach) { + DoEach(code_module::GetProjectClass()); + DoEach(code_module::GetPackageClass()); + DoEach(code_module::GetFuncDeclareClass()); + DoEach(code_module::GetCodeModuleClass()); + DoEach(axpr::GetDimExprClass()); + DoEach(index_expr::GetSliceClass()); + DoEach(index_expr::GetIndexExprClass()); + DoEach(index_expr::GetIndexTupleExprClass()); + DoEach(GetCodeGenResultClass()); +} + +template +axpr::AttrMap MakeBuiltinFrameAttrMap() { + axpr::AttrMap attr_map; + axpr::VisitEachBuiltinFrameAttr( + [&](const std::string& k, const ValueT& v) { attr_map->Set(k, v); }); + VisitEachBuiltinFrameClass( + [&](const auto& cls) { attr_map->Set(cls.Name(), cls); }); + return attr_map; +} + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/code_gen_ctx.h b/paddle/ap/include/code_gen/code_gen_ctx.h new file mode 100644 index 00000000000000..bb87bbe6d14368 --- /dev/null +++ b/paddle/ap/include/code_gen/code_gen_ctx.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_gen/arg_source_ctx.h" +#include "paddle/ap/include/code_module/adt.h" +#include "paddle/ap/include/code_module/arg_type.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" + +namespace ap::code_gen { + +template +struct CodeGenCtxImpl { + std::optional> ir_match_ctx; + + using DrrNode = drr::Node; + using DrrPackedIrOp = drr::PackedIrOp; + + DrrPackedIrOp res_ptn_ir_op; + + ArgSourceCtx arg_source_ctx; + + bool operator==(const CodeGenCtxImpl& other) const { return this == &other; } +}; + +template +ADT_DEFINE_RC(CodeGenCtx, CodeGenCtxImpl); + +template +axpr::TypeImpl> GetCodeGenCtxClass(); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/code_gen_ctx_method_class.h b/paddle/ap/include/code_gen/code_gen_ctx_method_class.h new file mode 100644 index 00000000000000..b46930bb8a0cc0 --- /dev/null +++ b/paddle/ap/include/code_gen/code_gen_ctx_method_class.h @@ -0,0 +1,290 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/packed_args.h" +#include "paddle/ap/include/code_gen/arg_source_helper.h" +#include "paddle/ap/include/code_gen/cuda_code_gen_util.h" +#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h" +#include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h" +#include "paddle/ap/include/code_gen/ir_op.h" +#include "paddle/ap/include/code_gen/kernel_arg_id_helper.h" +#include "paddle/ap/include/code_gen/op_code_gen_ctx.h" +#include "paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h" +#include "paddle/ap/include/code_module/code_module.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/ap/include/ir_match/native_or_ref_ir_value.h" +#include "paddle/ap/include/registry/registry_singleton.h" + +namespace ap::code_gen { + +using ap::axpr::BuiltinBinaryFunc; +using ap::axpr::BuiltinFuncType; +using ap::axpr::BuiltinUnaryFunc; +using ap::axpr::CppDataType; +using ap::axpr::CppPointerType; +using ap::axpr::DataType; +using ap::axpr::MethodClass; +using ap::axpr::PointerType; + +template +struct CodeGenCtxMethodClass { + using This = CodeGenCtxMethodClass; + using Self = CodeGenCtx; + + static adt::Result StaticMakeAndCheckOutTensorDataPtrKernelArgId( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return 
This{}.MakeAndCheckOutTensorDataPtrKernelArgId(self, args); + } + + adt::Result MakeAndCheckOutTensorDataPtrKernelArgId( + const Self& self, const std::vector& args) { + ADT_CHECK(args.size() == 1) + << adt::errors::TypeError{std::string() + + "out_tensor_data_ptr_kernel_" + "arg_id() takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(ir_value, CastToBirValue(args.at(0))) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of " + "out_tensor_data_ptr_kernel_arg_id() should be " + "'NativeIrValue' or 'RefIrValue' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_RETURN_IF_ERR(CheckOutTensorDataPtrRuntimeAvailable(self, ir_value)); + OutTensorDataPtrKernelArgId uninitialized{ir_value, std::nullopt}; + ArgSourceHelper helper{self->arg_source_ctx}; + ADT_LET_CONST_REF(runtime_getter, + helper.MakeRuntimeKerneArgGetter(uninitialized)); + OutTensorDataPtrKernelArgId kernel_arg_id{ir_value, + runtime_getter}; + axpr::BuiltinClassInstance instance{ + GetOutTensorDataPtrKernelArgIdClass(), kernel_arg_id}; + return instance; + } + + adt::Result CheckOutTensorDataPtrRuntimeAvailable( + const Self& self, const BirNode& ir_value) { + ADT_CHECK(self->arg_source_ctx->GetOutputTensorSource(ir_value).has_value()) + << adt::errors::TypeError{ + std::string() + + "out_tensor_data_ptr_kernel_arg_id() failed. 
" + "please check whether the ir_value is an output value of the " + "current ap_pattern_fusion_op defined in drr result pattern " + "lambda."}; + return adt::Ok{}; + } + + static adt::Result StaticMakeAndCheckInTensorDataPtrKernelArgId( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.MakeAndCheckInTensorDataPtrKernelArgId(self, args); + } + + adt::Result MakeAndCheckInTensorDataPtrKernelArgId( + const Self& self, const std::vector& args) { + ADT_CHECK(args.size() == 1) + << adt::errors::TypeError{std::string() + + "in_tensor_data_ptr_kernel_" + "arg_id() takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(ir_value, CastToBirValue(args.at(0))) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of " + "in_tensor_data_ptr_kernel_arg_id() should be " + "'NativeIrValue' or 'RefIrValue' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_RETURN_IF_ERR(CheckInTensorDataPtrRuntimeAvailable(self, ir_value)); + InTensorDataPtrKernelArgId uninitialized{ir_value, std::nullopt}; + ArgSourceHelper helper{self->arg_source_ctx}; + ADT_LET_CONST_REF(runtime_getter, + helper.MakeRuntimeKerneArgGetter(uninitialized)); + InTensorDataPtrKernelArgId kernel_arg_id{ir_value, runtime_getter}; + axpr::BuiltinClassInstance instance{ + GetInTensorDataPtrKernelArgIdClass(), kernel_arg_id}; + return instance; + } + + adt::Result CheckInTensorDataPtrRuntimeAvailable( + const Self& self, const BirNode& ir_value) { + ADT_CHECK(self->arg_source_ctx->GetInputTensorSource(ir_value).has_value()) + << adt::errors::TypeError{ + std::string() + + "in_tensor_data_ptr_kernel_arg_id() failed. 
" + "please check whether the ir_value is an input value of the " + "current ap_pattern_fusion_op defined in drr result pattern " + "lambda."}; + return adt::Ok{}; + } + + adt::Result CastToBirValue(const ValueT& val) { + ADT_LET_CONST_REF( + instance, val.template CastTo>()); + if (instance.template Has()) { + ADT_LET_CONST_REF( + ret, instance.template TryGet()); + return ret; + } + if (instance.template Has()) { + ADT_LET_CONST_REF( + ret, instance.template TryGet()); + return ret; + } + return adt::errors::NotImplementedError{ + std::string() + + "CastToBirValue() failed. only 'NativeIrValue' and 'RefIrValue' " + "argument is expected, but '" + + axpr::GetTypeName(val) + "' found."}; + } + + static adt::Result StaticMakeAndCheckDimExprKernelArgId( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.MakeAndCheckDimExprKernelArgId(self, args); + } + + adt::Result MakeAndCheckDimExprKernelArgId( + const Self& self, const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "dim_expr_kernel_arg_id() takes 1 arguments but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(dim_expr, args.at(0).template CastTo()) + << adt::errors::TypeError{std::string() + + "the argument 1 of dim_expr_kernel_arg_id() " + "should be 'DimExpr' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_RETURN_IF_ERR(CheckDimExprRuntimeAvailable(self, dim_expr)); + DimExprKernelArgId uninitialized{dim_expr, std::nullopt}; + ArgSourceHelper helper{self->arg_source_ctx}; + ADT_LET_CONST_REF(runtime_getter, + helper.MakeRuntimeKerneArgGetter(uninitialized)); + DimExprKernelArgId kernel_arg_id{dim_expr, runtime_getter}; + axpr::BuiltinClassInstance instance{ + GetDimExprKernelArgIdClass(), kernel_arg_id}; + return instance; + } + + adt::Result CheckDimExprRuntimeAvailable( + const Self& self, const symbol::DimExpr& dim_expr) { + 
ADT_CHECK(self->arg_source_ctx->HasDirectOrIndirectDimExprSource(dim_expr)) + << adt::errors::ValueError{ + std::string() + + "DimExpr could not evaluated in runtime. value: " + + symbol::ToString(dim_expr)}; + return adt::Ok{}; + } + + static adt::Result StaticMakeFusionOpCodeGenClass( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.MakeFusionOpCodeGenClass(self, args); + } + + using NativeOrRefIrValue = ir_match::NativeOrRefIrValue; + + adt::Result MakeFusionOpCodeGenClass( + const Self& self, const std::vector& packed_args_vec) { + const auto& packed_args = axpr::CastToPackedArgs(packed_args_vec); + const auto& [args, kwargs] = *packed_args; + ADT_CHECK(args->size() == 1) << adt::errors::TypeError{ + "'CodeGenCtx.make_fusion_op_code_gen_class' takes 1 positional " + "arguments but " + + std::to_string(args->size()) + " were given."}; + ADT_LET_CONST_REF(ir_op, IrOp::CastFrom(args->at(0))) + << adt::errors::TypeError{ + std::string() + + "the positional argument 1 of " + "'CodeGenCtx.make_fusion_op_code_gen_class' should " + "be able to cast to a NativeIrOp, PackedIrOp or RefIrOp."}; + ADT_LET_CONST_REF(input_index_loop_anchor_flags_lst, + kwargs->template Get>( + "input_index_loop_anchor_flags")) + << adt::errors::TypeError{ + std::string() + + "'CodeGenCtx.input_index_loop_anchor_flags' requires bool list " + "typed " + "keyword argument 'input_index_loop_anchor_flags'."}; + LoopAnchorFlags input_index_loop_anchor_flags; + { + input_index_loop_anchor_flags->reserve( + input_index_loop_anchor_flags_lst->size()); + for (const auto& elt : *input_index_loop_anchor_flags_lst) { + ADT_LET_CONST_REF(mask, elt.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "'CodeGenCtx.input_index_loop_anchor_flags' requires bool " + "list typed " + "keyword argument 'input_index_loop_anchor_flags'."}; + input_index_loop_anchor_flags->emplace_back( + tLoopAnchorFlag{mask}); + } + } + 
ADT_LET_CONST_REF(output_index_loop_anchor_flags_lst, + kwargs->template Get>( + "output_index_loop_anchor_flags")) + << adt::errors::TypeError{ + std::string() + + "'CodeGenCtx.output_index_loop_anchor_flags' requires bool list " + "typed " + "keyword argument 'output_index_loop_anchor_flags'."}; + LoopAnchorFlags output_index_loop_anchor_flags; + { + output_index_loop_anchor_flags->reserve( + output_index_loop_anchor_flags_lst->size()); + for (const auto& elt : *output_index_loop_anchor_flags_lst) { + ADT_LET_CONST_REF(mask, elt.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "'CodeGenCtx.output_index_loop_anchor_flags' requires bool " + "list typed " + "keyword argument 'output_index_loop_anchor_flags'."}; + output_index_loop_anchor_flags->emplace_back( + tLoopAnchorFlag{mask}); + } + } + + OpCodeGenCtx op_code_gen_ctx{self.shared_ptr(), + input_index_loop_anchor_flags, + output_index_loop_anchor_flags}; + ADT_LET_CONST_REF( + class_attrs, + ConvertFusionOpToClassAttrs(op_code_gen_ctx, ir_op)); + return axpr::TypeImpl>(class_attrs); + } +}; + +template +axpr::TypeImpl> GetCodeGenCtxClass() { + using ImplMethods = CodeGenCtxMethodClass; + static auto cls( + axpr::MakeBuiltinClass("CodeGenCtx", [&](const auto& Define) { + Define("make_fusion_op_code_gen_class", + &ImplMethods::StaticMakeFusionOpCodeGenClass); + Define("dim_expr_kernel_arg_id", + &ImplMethods::StaticMakeAndCheckDimExprKernelArgId); + Define("in_tensor_data_ptr_kernel_arg_id", + &ImplMethods::StaticMakeAndCheckInTensorDataPtrKernelArgId); + Define("out_tensor_data_ptr_kernel_arg_id", + &ImplMethods::StaticMakeAndCheckOutTensorDataPtrKernelArgId); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/code_gen_result.h b/paddle/ap/include/code_gen/code_gen_result.h new file mode 100644 index 00000000000000..4cf60f339b8ee3 --- /dev/null +++ 
b/paddle/ap/include/code_gen/code_gen_result.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_module/code_module.h" + +namespace ap::code_gen { + +template +struct CodeGenResultImpl { + code_module::CodeModule code_module; + axpr::Function kernel_dispatch_func; + axpr::AttrMap kernel_dispatch_const_data; + + bool operator==(const CodeGenResultImpl& other) const { + return this == &other; + } +}; + +template +ADT_DEFINE_RC(CodeGenResult, CodeGenResultImpl); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/code_gen_result_method_class.h b/paddle/ap/include/code_gen/code_gen_result_method_class.h new file mode 100644 index 00000000000000..7d0f39aa21fb43 --- /dev/null +++ b/paddle/ap/include/code_gen/code_gen_result_method_class.h @@ -0,0 +1,25 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/code_gen/code_gen_result.h" + +namespace ap::code_gen { + +axpr::TypeImpl> GetCodeGenResultClass(); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/cuda_code_gen_util.h b/paddle/ap/include/code_gen/cuda_code_gen_util.h new file mode 100644 index 00000000000000..e3241ed07b7ba5 --- /dev/null +++ b/paddle/ap/include/code_gen/cuda_code_gen_util.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/code_gen/op_cuda_gen_impl.h" +#include "paddle/ap/include/code_gen/value.h" + +namespace ap::code_gen { + +template +adt::Result> +ConvertFusionOpToClassAttrs(const OpCodeGenCtx& op_code_gen_ctx, + const IrOp& ir_op) { + OpCudaCodeGenImpl impl{}; + ADT_LET_CONST_REF(class_attrs, + impl.ConvertFusionOpToClassAttrs(op_code_gen_ctx, ir_op)); + return class_attrs; +} + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h b/paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h new file mode 100644 index 00000000000000..8b14f8c9c452fd --- /dev/null +++ b/paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::code_gen { + +template +struct DimExprKernelArgIdImpl { + symbol::DimExpr dim_expr; + std::optional> runtime_getter; + + bool operator==(const DimExprKernelArgIdImpl& other) const { + return this->dim_expr == other.dim_expr; + } + + template + adt::Result CastData() const { + axpr::BuiltinClassInstance instance{axpr::GetDimExprClass(), + this->dim_expr}; + return ValueT{instance}; + } + + std::size_t GetHashValue() const { + return std::hash()(this->dim_expr); + } +}; + +template +ADT_DEFINE_RC(DimExprKernelArgId, DimExprKernelArgIdImpl); + +template +axpr::TypeImpl> GetDimExprKernelArgIdClass(); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h b/paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h new file mode 100644 index 00000000000000..5953632cda17e5 --- /dev/null +++ b/paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h @@ -0,0 +1,88 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/code_gen/code_gen_ctx.h" +#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h" +#include "paddle/ap/include/code_gen/kernel_arg_id_helper.h" + +namespace ap::code_gen { + +template +struct DimExprKernelArgIdMethodClass { + using This = DimExprKernelArgIdMethodClass; + using Self = DimExprKernelArgId; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + ss << self->dim_expr; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::size_t hash_value = std::hash()(self->dim_expr); + return static_cast(hash_value); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "value") { + return self->template CastData(); + } + if (attr_name == "type") { + return This{}.GetArgType(self); + } + if (attr_name == "runtime_getter") { + ADT_CHECK(self->runtime_getter.has_value()) + << adt::errors::ValueError{"no runtime getter initialized"}; + return self->runtime_getter.value(); + } + return adt::errors::AttributeError{ + std::string() + "'DimExprKernelArgId' instance has no attribute '" + + attr_name + "'."}; + } + + adt::Result GetArgType(const Self& self) { + KernelArgIdHelper helper; + ADT_LET_CONST_REF(arg_type, helper.GetArgType(self)); + return arg_type.template CastTo(); + } +}; + +template +axpr::TypeImpl> +GetDimExprKernelArgIdClass() { + using ImplMethods = DimExprKernelArgIdMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "DimExprKernelArgId", 
[&](const auto& Define) { + Define("__str__", &ImplMethods::ToString); + Define("__hash__", &ImplMethods::Hash); + Define("__getattr__", &ImplMethods::GetAttr); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h b/paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h new file mode 100644 index 00000000000000..e526eb73a145c5 --- /dev/null +++ b/paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h @@ -0,0 +1,67 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::code_gen { + +template +struct InTensorDataPtrKernelArgIdImpl { + BirNode ir_value; + std::optional> runtime_getter; + + bool operator==(const InTensorDataPtrKernelArgIdImpl& other) const { + return this->ir_value == other.ir_value; + } + + template + adt::Result CastData() const { + using RetT = adt::Result; + return this->ir_value.Match( + [&](const typename BirNode::native_value_type& impl) -> RetT { + return impl; + }, + [&](const typename BirNode::ref_value_type& impl) -> RetT { + return impl; + }, + [&](const auto& impl) -> RetT { + using T = std::decay_t; + return adt::errors::NotImplementedError{ + std::string() + + "CastData() failed, only NativeIrValue and RefIrValue supported, " + "but '" + + typeid(T).name() + "' found."}; + }); + } + + std::size_t GetHashValue() const { + return std::hash()(this->ir_value); + } +}; + +template +ADT_DEFINE_RC(InTensorDataPtrKernelArgId, + InTensorDataPtrKernelArgIdImpl); + +template +axpr::TypeImpl> +GetInTensorDataPtrKernelArgIdClass(); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h b/paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h new file mode 100644 index 00000000000000..3dcb4d3197f888 --- /dev/null +++ b/paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h @@ -0,0 +1,86 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h" +#include "paddle/ap/include/code_gen/kernel_arg_id_helper.h" + +namespace ap::code_gen { + +template +struct InTensorDataPtrKernelArgIdMethodClass { + using This = InTensorDataPtrKernelArgIdMethodClass; + using Self = InTensorDataPtrKernelArgId; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::size_t hash_value = self->ir_value.GetHashValue(); + return static_cast(hash_value); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template CastTo()); + if (attr_name == "type") { + return This{}.GetArgType(self); + } + if (attr_name == "runtime_getter") { + ADT_CHECK(self->runtime_getter.has_value()) + << adt::errors::ValueError{"no runtime getter initialized"}; + return self->runtime_getter.value(); + } + return adt::errors::AttributeError{ + std::string() + + "'InTensorDataPtrKernelArgId' instance has no 
attribute '" + attr_name + + "'."}; + } + + adt::Result GetArgType(const Self& self) { + KernelArgIdHelper helper; + ADT_LET_CONST_REF(arg_type, helper.GetArgType(self)); + return arg_type.template CastTo(); + } +}; + +template +axpr::TypeImpl> +GetInTensorDataPtrKernelArgIdClass() { + using ImplMethods = InTensorDataPtrKernelArgIdMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "InTensorDataPtrKernelArgId", [&](const auto& Define) { + Define("__str__", &ImplMethods::ToString); + Define("__hash__", &ImplMethods::Hash); + Define("__getattr__", &ImplMethods::GetAttr); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/ir_op.h b/paddle/ap/include/code_gen/ir_op.h new file mode 100644 index 00000000000000..fd542654a2f6e2 --- /dev/null +++ b/paddle/ap/include/code_gen/ir_op.h @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" + +namespace ap::code_gen { + +template +using IrOpImpl = std::variant; + +template +struct IrOp : public IrOpImpl { + using IrOpImpl::IrOpImpl; + ADT_DEFINE_VARIANT_METHODS(IrOpImpl); + + template + static adt::Result CastFrom(const ValueT& val) { + ADT_LET_CONST_REF( + instance, val.template CastTo>()); + if (instance.template Has()) { + ADT_LET_CONST_REF( + ret, instance.template TryGet()); + return ret; + } + if (instance.template Has()) { + ADT_LET_CONST_REF( + ret, instance.template TryGet()); + return ret; + } + if (instance.template Has()) { + ADT_LET_CONST_REF( + ret, instance.template TryGet()); + return ret; + } + return adt::errors::ValueError{"IrOp::CastFrom failed."}; + } +}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/kernel_arg.h b/paddle/ap/include/code_gen/kernel_arg.h new file mode 100644 index 00000000000000..1621bcc456591a --- /dev/null +++ b/paddle/ap/include/code_gen/kernel_arg.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_gen/kernel_arg_id.h" +#include "paddle/ap/include/code_module/adt.h" + +namespace ap::code_gen { + +struct KernelArgImpl { + KernelArgId kernel_arg_id; + axpr::Lambda getter_lambda; + + bool operator==(const KernelArgImpl& other) const { + return this->kernel_arg_id == other.kernel_arg_id && + this->getter_lambda == other.getter_lambda; + } +}; + +ADT_DEFINE_RC(KernelArg, KernelArgImpl); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/kernel_arg_id.h b/paddle/ap/include/code_gen/kernel_arg_id.h new file mode 100644 index 00000000000000..895cf0506c2576 --- /dev/null +++ b/paddle/ap/include/code_gen/kernel_arg_id.h @@ -0,0 +1,75 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h" +#include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h" +#include "paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h" + +namespace ap::code_gen { + +template +using KernelArgIdImpl = std::variant, + InTensorDataPtrKernelArgId, + OutTensorDataPtrKernelArgId>; + +template +struct KernelArgId : public KernelArgIdImpl { + using KernelArgIdImpl::KernelArgIdImpl; + + ADT_DEFINE_VARIANT_METHODS(KernelArgIdImpl); + + template + ValueT CastTo() const { + return Match([](const auto& impl) -> ValueT { return impl; }); + } + + template + static adt::Result CastFrom(const ValueT& val) { + using RetT = adt::Result; + return val.Match( + [](const DimExprKernelArgId& impl) -> RetT { return impl; }, + [](const InTensorDataPtrKernelArgId& impl) -> RetT { + return impl; + }, + [](const OutTensorDataPtrKernelArgId& impl) -> RetT { + return impl; + }, + [](const auto& impl) -> RetT { + return adt::errors::TypeError{"KernelArgId::CastFrom() failed."}; + }); + } + + std::size_t GetHashValue() const { + std::size_t hash_value = Match( + [&](const auto& impl) -> std::size_t { return impl->GetHashValue(); }); + return adt::hash_combine(this->index(), hash_value); + } +}; + +} // namespace ap::code_gen + +namespace std { + +template +struct hash> { + std::size_t operator()( + const ap::code_gen::KernelArgId& arg_id) const { + return arg_id.GetHashValue(); + } +}; + +} // namespace std diff --git a/paddle/ap/include/code_gen/kernel_arg_id_helper.h b/paddle/ap/include/code_gen/kernel_arg_id_helper.h new file mode 100644 index 00000000000000..9aa14350bdf7b9 --- /dev/null +++ b/paddle/ap/include/code_gen/kernel_arg_id_helper.h @@ -0,0 +1,87 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/pointer_type_util.h" +#include "paddle/ap/include/code_gen/kernel_arg_id.h" +#include "paddle/ap/include/code_module/arg_type.h" + +namespace ap::code_gen { + +template +struct KernelArgIdHelper { + using BirNativeIrValue = typename BirNode::native_value_type; + + template + adt::Result> CastToKernelArgId(const ValueT& val) { + using RetT = adt::Result>; + return val.Match( + [&](const DimExprKernelArgId& impl) -> RetT { return impl; }, + [&](const InTensorDataPtrKernelArgId& impl) -> RetT { + return impl; + }, + [&](const OutTensorDataPtrKernelArgId& impl) -> RetT { + return impl; + }, + [&](const auto& impl) -> RetT { + return adt::errors::TypeError{ + std::string() + + "only DimExprKernelArgId, InTensorDataPtrKernelArgId and " + "OutTensorDataPtrKernelArgId (not including '" + + axpr::GetTypeName(val) + "') can be cast to KernelArgId"}; + }); + } + + adt::Result GetArgType( + const KernelArgId& arg_id) { + using RetT = adt::Result; + return arg_id.Match( + [](const DimExprKernelArgId&) -> RetT { + return axpr::CppDataType(); + }, + [&](const InTensorDataPtrKernelArgId& in_data_ptr) -> RetT { + ADT_LET_CONST_REF(ir_value, + GetBirNativeIrValue(in_data_ptr->ir_value)); + ADT_LET_CONST_REF(data_type, ir_value.GetDataType()); + return axpr::GetConstPointerTypeFromDataType(data_type); + }, + [&](const OutTensorDataPtrKernelArgId& out_data_ptr) -> RetT { + 
ADT_LET_CONST_REF(ir_value, + GetBirNativeIrValue(out_data_ptr->ir_value)); + ADT_LET_CONST_REF(data_type, ir_value.GetDataType()); + return axpr::GetMutablePointerTypeFromDataType(data_type); + }); + } + + adt::Result GetBirNativeIrValue( + const BirNode& bir_node) const { + using RetT = adt::Result; + return bir_node.Match( + [&](const BirNativeIrValue& impl) -> RetT { return impl; }, + [&](const typename BirNode::ref_value_type& impl) -> RetT { + return impl.GetOwnerNativeIrValue(); + }, + [&](const auto& impl) -> RetT { + using T = std::decay_t; + return adt::errors::NotImplementedError{ + std::string() + + "GetBirNativeIrValue() failed. only 'NativeIrValue' and " + "'RefIrValue' argument expected, but '" + + typeid(T).name() + "' found."}; + }); + } +}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/loop_anchor_flags.h b/paddle/ap/include/code_gen/loop_anchor_flags.h new file mode 100644 index 00000000000000..b3dff32121cbe5 --- /dev/null +++ b/paddle/ap/include/code_gen/loop_anchor_flags.h @@ -0,0 +1,25 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::code_gen { + +ADT_DEFINE_TAG(tLoopAnchorFlag); + +using LoopAnchorFlags = adt::List>; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/matched_result_pattern_helper.h b/paddle/ap/include/code_gen/matched_result_pattern_helper.h new file mode 100644 index 00000000000000..f2402430ebe895 --- /dev/null +++ b/paddle/ap/include/code_gen/matched_result_pattern_helper.h @@ -0,0 +1,235 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/code_gen/arg_source_ctx.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/result_pattern_helper.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/graph/graph_helper.h" +#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" +#include "paddle/ap/include/ir_match/graph_match_ctx.h" +#include "paddle/ap/include/ir_match/graph_matcher.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" + +namespace ap::code_gen { + +template +struct MatchedResultPatternHelper { + using DrrNode = drr::Node; + + using DrrCtx = drr::DrrCtx; + + using DrrNativeIrValue = drr::NativeIrValue; + using DrrPackedIrValue = drr::PackedIrValue; + using DrrIrValue = drr::IrValue; + + using DrrNativeIrOp = drr::NativeIrOp; + using DrrNativeIrOpOperand = drr::NativeIrOpOperand; + using DrrNativeIrOpResult = drr::NativeIrOpResult; + using DrrPackedIrOp = drr::PackedIrOp; + using DrrPackedIrOpOperand = drr::PackedIrOpOperand; + using DrrPackedIrOpResult = drr::PackedIrOpResult; + using DrrOptPackedIrOp = drr::OptPackedIrOp; + using DrrOptPackedIrOpOperand = drr::OptPackedIrOpOperand; + using DrrOptPackedIrOpResult = drr::OptPackedIrOpResult; + + using DrrIrOpImpl = std::variant; + + using IrMatchCtx = ir_match::IrMatchCtx; + + using GraphMatchCtx = ir_match::GraphMatchCtx; + + const GraphMatchCtx& match_ctx_; + const DrrCtx& drr_ctx_; + + template + adt::Result VisitMatchedBirInputOfRestPtnPackedIrOp( + const DrrPackedIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + auto CollectInput = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + 
return VisitMatchedBirValueOfResPtnIrValue(drr_ir_value, DoEach); + }; + ADT_RETURN_IF_ERR( + VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, CollectInput)); + return adt::Ok{}; + } + + template + adt::Result VisitMatchedBirOutputOfRestPtnPackedIrOp( + const DrrPackedIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + auto DoEachDrrIrValue = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + return VisitMatchedBirValueOfResPtnIrValue(drr_ir_value, DoEach); + }; + ADT_RETURN_IF_ERR( + VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, DoEachDrrIrValue)); + return adt::Ok{}; + } + + template + adt::Result VisitResPtnInputIrValueByResPtnIrOp( + const DrrPackedIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + drr::ResultPatternHelper helper{drr_ctx_}; + return helper.VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, DoEach); + } + + template + adt::Result VisitMatchedBirValueOfResPtnIrValue( + const DrrIrValue& res_ptn_ir_value, const DoEachT& DoEach) const { + using Ok = adt::Result; + const auto& opt_ir_value = SrcPtnIrValue4ResPtnIrValue(res_ptn_ir_value); + ADT_CHECK(opt_ir_value.has_value()); + const auto& ir_value = opt_ir_value.value(); + ADT_RETURN_IF_ERR( + match_ctx_->VisitBigGraphIrValueNode(ir_value.node(), DoEach)); + return adt::Ok{}; + } + + template + adt::Result VisitApKernelInputIndexOrSlice( + const DrrPackedIrOp& res_ptn_ir_op, + const DoEachIndexT& DoEachIndex, + const DoEachSliceT& DoEachSlice) const { + std::size_t start = 0; + using Ok = adt::Result; + auto DoEachIrValue = [&](const DrrIrValue& drr_ir_value) -> Ok { + ADT_LET_CONST_REF(num_ir_values, GetResPtnNumBirValues(drr_ir_value)); + ADT_RETURN_IF_ERR(drr_ir_value.Match( + [&](const DrrNativeIrValue&) -> Ok { + ADT_RETURN_IF_ERR(DoEachIndex(start)); + return adt::Ok{}; + }, + [&](const DrrPackedIrValue&) -> Ok { + ADT_RETURN_IF_ERR(DoEachSlice(start, start + num_ir_values)); + return adt::Ok{}; + })); + start += num_ir_values; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( 
+ VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, DoEachIrValue)); + return adt::Ok{}; + } + + template + adt::Result VisitApKernelOutputIndexOrSlice( + const DrrPackedIrOp& res_ptn_ir_op, + const DoEachIndexT& DoEachIndex, + const DoEachSliceT& DoEachSlice) const { + std::size_t start = 0; + using Ok = adt::Result; + auto DoEachIrValue = [&](const DrrIrValue& drr_ir_value) -> Ok { + ADT_LET_CONST_REF(num_ir_values, GetResPtnNumBirValues(drr_ir_value)); + ADT_RETURN_IF_ERR(drr_ir_value.Match( + [&](const DrrNativeIrValue&) -> Ok { + ADT_RETURN_IF_ERR(DoEachIndex(start)); + return adt::Ok{}; + }, + [&](const DrrPackedIrValue&) -> Ok { + ADT_RETURN_IF_ERR(DoEachSlice(start, start + num_ir_values)); + return adt::Ok{}; + })); + start += num_ir_values; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, DoEachIrValue)); + return adt::Ok{}; + } + + adt::Result GetApKernelNumOutputs( + const DrrPackedIrOp& res_ptn_ir_op) const { + std::size_t num_outputs = 0; + auto AccNumOutputs = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + ADT_LET_CONST_REF(num_ir_values, GetResPtnNumBirValues(drr_ir_value)); + num_outputs += num_ir_values; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, AccNumOutputs)); + return num_outputs; + } + + template + adt::Result VisitEachMatchedDrrIrValueAndOutputSlice( + const std::vector& output_values, + const IrOpT& res_ptn_ir_op, + const DoEachT& DoEach) const { + std::size_t offset = 0; + auto DoEachSlice = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + ADT_LET_CONST_REF(num_ir_values, GetResPtnNumBirValues(drr_ir_value)); + ADT_CHECK(offset + num_ir_values <= output_values.size()); + std::vector slice{output_values.begin() + offset, + output_values.begin() + offset + num_ir_values}; + return DoEach(drr_ir_value, slice); + }; + return VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, DoEachSlice); + } + + template + 
adt::Result VisitResPtnOutputIrValueByResPtnIrOp( + const IrOpT& res_ptn_ir_op, const DoEachT& DoEach) const { + drr::ResultPatternHelper helper{drr_ctx_}; + return helper.VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, DoEach); + } + + adt::Result GetResPtnNumBirValues( + const DrrIrValue& res_ptn_ir_value) const { + const auto& opt_src_ptn_ir_value = + SrcPtnIrValue4ResPtnIrValue(res_ptn_ir_value); + if (!opt_src_ptn_ir_value.has_value()) { + // internal ir value in result pattern. + return 1; + } + return match_ctx_->GetNumBigGraphIrValueNodes( + opt_src_ptn_ir_value.value().node()); + } + + std::optional SrcPtnIrValue4ResPtnIrValue( + const DrrIrValue& res_ptn_ir_value) const { + drr::ResultPatternHelper helper{drr_ctx_}; + return helper.SrcPtnIrValue4ResPtnIrValue(res_ptn_ir_value); + } + + using BirNativeIrValue = typename BirNode::native_value_type; + + adt::Result CastToBirNativeIrValue( + const BirNode& bir_node) const { + using RetT = adt::Result; + return bir_node.Match( + [&](const typename BirNode::native_value_type& bir_value) -> RetT { + return bir_value; + }, + [&](const typename BirNode::ref_value_type& ref_value) -> RetT { + return ref_value.GetOwnerNativeIrValue(); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + "bir_node is not an PirNode::native_value_type or " + "BirNode::ref_value_type"}; + }); + } +}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/op_code_gen_ctx.h b/paddle/ap/include/code_gen/op_code_gen_ctx.h new file mode 100644 index 00000000000000..8dbbd8e941d2ac --- /dev/null +++ b/paddle/ap/include/code_gen/op_code_gen_ctx.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_gen/kernel_arg_id.h" +#include "paddle/ap/include/code_gen/loop_anchor_flags.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/ap/include/ir_match/native_or_ref_ir_value.h" + +namespace ap::code_gen { + +template +struct CodeGenCtxImpl; + +template +struct OpCodeGenCtxImpl { + std::weak_ptr> code_gen_ctx; + + LoopAnchorFlags input_index_loop_anchor_flags; + LoopAnchorFlags output_index_loop_anchor_flags; + + bool operator==(const OpCodeGenCtxImpl& other) const { + return this == &other; + } +}; + +template +ADT_DEFINE_RC(OpCodeGenCtx, OpCodeGenCtxImpl); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/op_cuda_gen_impl.h b/paddle/ap/include/code_gen/op_cuda_gen_impl.h new file mode 100644 index 00000000000000..35ea5be08c9b74 --- /dev/null +++ b/paddle/ap/include/code_gen/op_cuda_gen_impl.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/class_attrs.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/code_gen/ir_op.h" +#include "paddle/ap/include/code_gen/op_code_gen_ctx.h" + +namespace ap::code_gen { + +template +struct OpCudaCodeGenImpl { + adt::Result CodeGen(const OpCodeGenCtx& op_code_gen_ctx, + const IrOp& ir_op); + adt::Result> + ConvertFusionOpToClassAttrs(const OpCodeGenCtx& op_code_gen_ctx, + const IrOp& ir_op); +}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h b/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h new file mode 100644 index 00000000000000..fca622f0ec4983 --- /dev/null +++ b/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h @@ -0,0 +1,67 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::code_gen { + +template +struct OutTensorDataPtrKernelArgIdImpl { + BirNode ir_value; + std::optional> runtime_getter; + + bool operator==(const OutTensorDataPtrKernelArgIdImpl& other) const { + return this->ir_value == other.ir_value; + } + + template + adt::Result CastData() const { + using RetT = adt::Result; + return this->ir_value.Match( + [&](const typename BirNode::native_value_type& impl) -> RetT { + return impl; + }, + [&](const typename BirNode::ref_value_type& impl) -> RetT { + return impl; + }, + [&](const auto& impl) -> RetT { + using T = std::decay_t; + return adt::errors::NotImplementedError{ + std::string() + + "CastData() failed, only NativeIrValue and RefIrValue supported, " + "but '" + + typeid(T).name() + "' found."}; + }); + } + + std::size_t GetHashValue() const { + return std::hash()(this->ir_value); + } +}; + +template +ADT_DEFINE_RC(OutTensorDataPtrKernelArgId, + OutTensorDataPtrKernelArgIdImpl); + +template +axpr::TypeImpl> +GetOutTensorDataPtrKernelArgIdClass(); + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h b/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h new file mode 100644 index 00000000000000..e469f04cb03694 --- /dev/null +++ b/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h @@ -0,0 +1,86 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h" +#include "paddle/ap/include/code_gen/kernel_arg_id_helper.h" + +namespace ap::code_gen { + +template +struct OutTensorDataPtrKernelArgIdMethodClass { + using This = OutTensorDataPtrKernelArgIdMethodClass; + using Self = OutTensorDataPtrKernelArgId; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::size_t hash_value = self->ir_value.GetHashValue(); + return static_cast(hash_value); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "type") { + return This{}.GetArgType(self); + } + if (attr_name == "runtime_getter") { + ADT_CHECK(self->runtime_getter.has_value()) + << adt::errors::ValueError{"no runtime getter initialized"}; + return self->runtime_getter.value(); + } + return adt::errors::AttributeError{ + std::string() + + "'OutTensorDataPtrKernelArgId' instance has no 
attribute '" + + attr_name + "'."}; + } + + adt::Result GetArgType(const Self& self) { + KernelArgIdHelper helper; + ADT_LET_CONST_REF(arg_type, helper.GetArgType(self)); + return arg_type.template CastTo(); + } +}; + +template +axpr::TypeImpl> +GetOutTensorDataPtrKernelArgIdClass() { + using ImplMethods = OutTensorDataPtrKernelArgIdMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "OutTensorDataPtrKernelArgId", [&](const auto& Define) { + Define("__str__", &ImplMethods::ToString); + Define("__hash__", &ImplMethods::Hash); + Define("__getattr__", &ImplMethods::GetAttr); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_gen diff --git a/paddle/ap/include/code_gen/value.h b/paddle/ap/include/code_gen/value.h new file mode 100644 index 00000000000000..835641eb8e9fdf --- /dev/null +++ b/paddle/ap/include/code_gen/value.h @@ -0,0 +1,37 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/builtin_serializable_attr_map.h"
+#include "paddle/ap/include/axpr/dim_expr.h"
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/code_gen/code_gen_ctx.h"
+#include "paddle/ap/include/code_gen/code_gen_result.h"
+#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h"
+#include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id.h"
+#include "paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h"
+#include "paddle/ap/include/code_module/adt.h"
+#include "paddle/ap/include/code_module/code_module.h"
+#include "paddle/ap/include/code_module/data_type.h"
+#include "paddle/ap/include/index_expr/index_expr.h"
+#include "paddle/ap/include/index_expr/index_tuple_expr.h"
+#include "paddle/ap/include/ir_match/op_match_ctx.h"
+#include "paddle/ap/include/ir_match/tensor_match_ctx.h"
+
+namespace ap::code_gen {
+
+using axpr::Value;
+
+}  // namespace ap::code_gen
diff --git a/paddle/ap/include/code_gen/value_method_class.h b/paddle/ap/include/code_gen/value_method_class.h
new file mode 100644
index 00000000000000..9d6d33cfdf8bba
--- /dev/null
+++ b/paddle/ap/include/code_gen/value_method_class.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/axpr/dim_expr_method_class.h"
+#include "paddle/ap/include/code_gen/code_gen_ctx_method_class.h"
+#include "paddle/ap/include/code_gen/code_gen_result_method_class.h"
+#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h"
+#include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h"
+#include "paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h"
+#include "paddle/ap/include/index_expr/index_expr_method_class.h"
+#include "paddle/ap/include/index_expr/index_tuple_expr_method_class.h"
+#include "paddle/ap/include/index_expr/slice_method_class.h"
+#include "paddle/ap/include/ir_match/op_match_ctx_method_class.h"
+#include "paddle/ap/include/ir_match/tensor_match_ctx_method_class.h"
diff --git a/paddle/ap/include/code_module/adt.h b/paddle/ap/include/code_module/adt.h
new file mode 100644
index 00000000000000..f030c87cbb2304
--- /dev/null
+++ b/paddle/ap/include/code_module/adt.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+
+namespace ap {
+
+using adt::errors::AttributeError;
+using adt::errors::Error;
+using adt::errors::IndexError;
+using adt::errors::InvalidArgumentError;
+using adt::errors::NameError;
+using adt::errors::RuntimeError;
+using adt::errors::SyntaxError;
+using adt::errors::TypeError;
+using adt::errors::ValueError;
+
+template <typename T>
+using Result = adt::Result<T>;
+}  // namespace ap
diff --git a/paddle/ap/include/code_module/api_wrapper_project_maker.h b/paddle/ap/include/code_module/api_wrapper_project_maker.h
new file mode 100644
index 00000000000000..0e3b0ce24ea165
--- /dev/null
+++ b/paddle/ap/include/code_module/api_wrapper_project_maker.h
@@ -0,0 +1,254 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <sstream>
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/code_module/func_declare.h"
+#include "paddle/ap/include/code_module/project.h"
+
+namespace ap::code_module {
+
+struct ApiWrapperProjectMaker {
+  adt::Result<Project> Make(const std::vector<FuncDeclare>& func_declares) {
+    ADT_LET_CONST_REF(nested_files, MakeNestedFiles(func_declares));
+    ADT_LET_CONST_REF(compile_cmd, GetCompileCmd());
+    ADT_LET_CONST_REF(so_relative_path, GetSoRelativePath());
+    axpr::AttrMap<axpr::Value> others;
+    return Project{nested_files, compile_cmd, so_relative_path, others};
+  }
+
+  adt::Result<Directory<File>> MakeNestedFiles(
+      const std::vector<FuncDeclare>& func_declares) {
+    Directory<File> directory;
+    ADT_LET_CONST_REF(file_content,
+                      GenerateApiWrapperCFileContent(func_declares));
+    directory.dentry2file->Set("api_wrapper.c", FileContent{file_content});
+    return directory;
+  }
+
+  adt::Result<std::string> GenerateApiWrapperCFileContent(
+      const std::vector<FuncDeclare>& func_declares) {
+    std::ostringstream ss;
+    ss << "#include <stdint.h>" << std::endl << std::endl;
+    for (const auto& func_declare : func_declares) {
+      ADT_RETURN_IF_ERR(GenerateCCode4FuncDeclare(&ss, func_declare));
+      ss << std::endl;
+    }
+    return ss.str();
+  }
+
+  adt::Result<adt::Ok> GenerateCCode4FuncDeclare(
+      std::ostringstream* ss, const FuncDeclare& func_declare) {
+    (*ss) << "void " << func_declare->func_id
+          << "(void* ret, void* f, void** args) {" << std::endl;
+    ADT_LET_CONST_REF(func_ptr_var, DeclareFuncPtrType(func_declare, "func"));
+    (*ss) << "  " << func_ptr_var << " = f"
+          << ";\n";
+    ADT_LET_CONST_REF(func_call_str,
+                      GenerateFuncCall(func_declare, "func", "args"));
+    if (IsVoidRet(func_declare)) {
+      (*ss) << "  " << func_call_str << ";\n";
+    } else {
+      ADT_LET_CONST_REF(ret_type, GenCode4ArgType(func_declare->ret_type));
+      (*ss) << "  *(" << ret_type << "*)ret = " << func_call_str << ";\n";
+    }
+    (*ss) << "}\n" << std::endl;
+    return adt::Ok{};
+  }
+
+  bool IsVoidRet(const FuncDeclare& func_declare) {
+    return func_declare->ret_type.Match(
+        [&](const axpr::DataType& data_type) -> bool {
+          return data_type.template Has<axpr::CppDataType<void>>();
+        },
+        [&](const axpr::PointerType& pointer_type) -> bool { return false; });
+  }
+
+  adt::Result<std::string> GenerateFuncCall(const FuncDeclare& func_declare,
+                                            const std::string& func_var_name,
+                                            const std::string& args_var_name) {
+    std::ostringstream ss;
+    ss << func_var_name << "(";
+    for (int i = 0; i < func_declare->arg_types->size(); ++i) {
+      if (i > 0) {
+        ss << ", ";
+      }
+      const auto& arg_type = func_declare->arg_types->at(i);
+      ADT_LET_CONST_REF(arg_type_str, GenCode4ArgType(arg_type));
+      ss << "*(" << arg_type_str << "*)" << args_var_name << "[" << i << "]";
+    }
+    ss << ")";
+    return ss.str();
+  }
+
+  adt::Result<std::string> DeclareFuncPtrType(const FuncDeclare& func_declare,
+                                              const std::string& func_name) {
+    std::ostringstream ss;
+    ADT_LET_CONST_REF(ret_type, GenCode4ArgType(func_declare->ret_type));
+    ss << ret_type << "(*" << func_name << ")(";
+    int i = 0;
+    for (const auto& arg_type : *func_declare->arg_types) {
+      if (i++ > 0) {
+        ss << ", ";
+      }
+      ADT_LET_CONST_REF(arg_type_str, GenCode4ArgType(arg_type));
+      ss << arg_type_str;
+    }
+    ss << ")";
+    return ss.str();
+  }
+
+  adt::Result<std::string> GenCode4ArgType(const ArgType& arg_type) {
+    return arg_type.Match(
+        [&](const axpr::DataType& data_type) -> adt::Result<std::string> {
+          return GenCode4DataType(data_type);
+        },
+        [&](const axpr::PointerType& pointer_type) -> adt::Result<std::string> {
+          return GenCode4PointerType(pointer_type);
+        });
+  }
+
+  adt::Result<std::string> GenCode4DataType(const axpr::DataType& data_type) {
+    using RetT = adt::Result<std::string>;
+    return data_type.Match(
+        [&](axpr::CppDataType<bool>) -> RetT { return "bool"; },
+        [&](axpr::CppDataType<int8_t>) -> RetT { return "int8_t"; },
+        [&](axpr::CppDataType<uint8_t>) -> RetT { return "uint8_t"; },
+        [&](axpr::CppDataType<int16_t>) -> RetT { return "int16_t"; },
+        [&](axpr::CppDataType<uint16_t>) -> RetT { return "uint16_t"; },
+        [&](axpr::CppDataType<int32_t>) -> RetT { return "int32_t"; },
+        [&](axpr::CppDataType<uint32_t>) -> RetT { return "uint32_t"; },
+        [&](axpr::CppDataType<int64_t>) -> RetT { return "int64_t"; },
+        [&](axpr::CppDataType<uint64_t>) -> RetT { return "uint64_t"; },
+        [&](axpr::CppDataType<float>) -> RetT { return "float"; },
+        [&](axpr::CppDataType<double>) -> RetT { return "double"; },
+        [&](axpr::CppDataType<bfloat16>) -> RetT {
+          return adt::errors::TypeError{
+              "bfloat16 are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<float8_e4m3fn>) -> RetT {
+          return adt::errors::TypeError{
+              "float8_e4m3fn are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<float8_e5m2>) -> RetT {
+          return adt::errors::TypeError{
+              "float8_e5m2 are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<float16>) -> RetT {
+          return adt::errors::TypeError{
+              "float16 are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<complex64>) -> RetT {
+          return adt::errors::TypeError{
+              "complex64 are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<complex128>) -> RetT {
+          return adt::errors::TypeError{
+              "complex128 are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<pstring>) -> RetT {
+          return adt::errors::TypeError{
+              "pstring are not allowed being used by so function"};
+        },
+        [&](axpr::CppDataType<void>) -> RetT { return "void"; });
+  }
+
+  adt::Result<std::string> GenCode4PointerType(
+      const axpr::PointerType& pointer_type) {
+    using RetT = adt::Result<std::string>;
+    return pointer_type.Match(
+        [&](axpr::CppPointerType<bool*>) -> RetT { return "bool*"; },
+        [&](axpr::CppPointerType<int8_t*>) -> RetT { return "int8_t*"; },
+        [&](axpr::CppPointerType<uint8_t*>) -> RetT { return "uint8_t*"; },
+        [&](axpr::CppPointerType<int16_t*>) -> RetT { return "int16_t*"; },
+        [&](axpr::CppPointerType<uint16_t*>) -> RetT { return "uint16_t*"; },
+        [&](axpr::CppPointerType<int32_t*>) -> RetT { return "int32_t*"; },
+        [&](axpr::CppPointerType<uint32_t*>) -> RetT { return "uint32_t*"; },
+        [&](axpr::CppPointerType<int64_t*>) -> RetT { return "int64_t*"; },
+        [&](axpr::CppPointerType<uint64_t*>) -> RetT { return "uint64_t*"; },
+        [&](axpr::CppPointerType<float*>) -> RetT { return "float*"; },
+        [&](axpr::CppPointerType<double*>) -> RetT { return "void*"; },
+        [&](axpr::CppPointerType<bfloat16*>) -> RetT { return "void*"; },
+        [&](axpr::CppPointerType<float8_e4m3fn*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<float8_e5m2*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<float16*>) -> RetT { return "void*"; },
+        [&](axpr::CppPointerType<complex64*>) -> RetT { return "void*"; },
+        [&](axpr::CppPointerType<complex128*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<pstring*>) -> RetT { return "void*"; },
+        [&](axpr::CppPointerType<void*>) -> RetT { return "void*"; },
+        [&](axpr::CppPointerType<const bool*>) -> RetT { return "bool*"; },
+        [&](axpr::CppPointerType<const int8_t*>) -> RetT { return "int8_t*"; },
+        [&](axpr::CppPointerType<const uint8_t*>) -> RetT {
+          return "uint8_t*";
+        },
+        [&](axpr::CppPointerType<const int16_t*>) -> RetT {
+          return "int16_t*";
+        },
+        [&](axpr::CppPointerType<const uint16_t*>) -> RetT {
+          return "uint16_t*";
+        },
+        [&](axpr::CppPointerType<const int32_t*>) -> RetT {
+          return "int32_t*";
+        },
+        [&](axpr::CppPointerType<const uint32_t*>) -> RetT {
+          return "uint32_t*";
+        },
+        [&](axpr::CppPointerType<const int64_t*>) -> RetT {
+          return "int64_t*";
+        },
+        [&](axpr::CppPointerType<const uint64_t*>) -> RetT {
+          return "uint64_t*";
+        },
+        [&](axpr::CppPointerType<const float*>) -> RetT { return "float*"; },
+        [&](axpr::CppPointerType<const double*>) -> RetT { return "double*"; },
+        [&](axpr::CppPointerType<const bfloat16*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const float8_e4m3fn*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const float8_e5m2*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const float16*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const complex64*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const complex128*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const pstring*>) -> RetT {
+          return "void*";
+        },
+        [&](axpr::CppPointerType<const void*>) -> RetT { return "void*"; });
+  }
+
+  adt::Result<std::string> GetCompileCmd() {
+    return "gcc -fPIC -shared api_wrapper.c -o api_wrapper.so";
+  }
+
+  adt::Result<std::string> GetSoRelativePath() { return "api_wrapper.so"; }
+};
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/arg_type.h b/paddle/ap/include/code_module/arg_type.h
new file mode 100644
index 00000000000000..0b1cd05a5c8bff
--- /dev/null
+++ 
b/paddle/ap/include/code_module/arg_type.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/axpr/value_method_class.h"
+#include "paddle/ap/include/code_module/adt.h"
+#include "paddle/ap/include/code_module/data_type.h"
+
+namespace phi {
+
+class DenseTensor;
+
+}
+
+namespace ap::code_module {
+
+using ap::axpr::DataType;
+using ap::axpr::MethodClass;
+using ap::axpr::PointerType;
+
+using ArgTypeImpl = std::variant<DataType, PointerType>;
+
+struct ArgType : public ArgTypeImpl {
+  using ArgTypeImpl::ArgTypeImpl;
+  ADT_DEFINE_VARIANT_METHODS(ArgTypeImpl);
+
+  const char* Name() const {
+    return Match([](const auto& impl) { return impl.Name(); });
+  }
+
+  template <typename T>
+  bool IsType() const {
+    if constexpr (std::is_pointer_v<T>) {
+      const auto& pointer_type = this->template TryGet<PointerType>();
+      if (pointer_type.HasError()) {
+        return false;
+      }
+      return pointer_type.GetOkValue()
+          .template Has<axpr::CppPointerType<T>>();
+    } else {
+      const auto& data_type = this->template TryGet<DataType>();
+      if (data_type.HasError()) {
+        return false;
+      }
+      return data_type.GetOkValue().template Has<axpr::CppDataType<T>>();
+    }
+  }
+
+  template <typename T>
+  adt::Result<T> CastTo() const {
+    return Match([](const auto& impl) -> adt::Result<T> { return impl; });
+  }
+};
+
+template <typename ValueT>
+Result<ArgType> CastToArgType(const ValueT& val) {
+  return val.Match(
+      [&](const DataType& atype) -> Result<ArgType> { return ArgType{atype}; },
+      [&](const PointerType& ptype) -> Result<ArgType> {
+        return ArgType{ptype};
+      },
+      [&](const auto&) -> Result<ArgType> {
+        return adt::errors::TypeError{std::string() +
+                                      "CastToArgType failed. expected types: "
+                                      "(DataType, PointerType), actual type: " +
+                                      axpr::GetTypeName(val)};
+      });
+}
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/builtin_frame_util.h b/paddle/ap/include/code_module/builtin_frame_util.h
new file mode 100644
index 00000000000000..49be3b44b27054
--- /dev/null
+++ b/paddle/ap/include/code_module/builtin_frame_util.h
@@ -0,0 +1,51 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/builtin_frame_util.h"
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/code_module/code_module_method_class.h"
+#include "paddle/ap/include/code_module/directory_method_class.h"
+#include "paddle/ap/include/code_module/file_content_method_class.h"
+#include "paddle/ap/include/code_module/func_declare_method_class.h"
+#include "paddle/ap/include/code_module/package_method_class.h"
+#include "paddle/ap/include/code_module/project_method_class.h"
+#include "paddle/ap/include/code_module/soft_link_method_class.h"
+
+namespace ap::code_module {
+
+template <typename DoEachT>
+void VisitEachBuiltinFrameAttr(const DoEachT& DoEach) {
+  DoEach(GetFileContentClass());
+  DoEach(GetSoftLinkClass());
+  DoEach(GetDirectoryClass());
+  DoEach(GetProjectClass());
+  DoEach(GetPackageClass());
+  DoEach(GetFuncDeclareClass());
+  DoEach(GetCodeModuleClass());
+}
+
+template <typename ValueT>
+axpr::AttrMap<ValueT> MakeBuiltinFrameAttrMap() {
+  axpr::AttrMap<ValueT> attr_map;
+  axpr::VisitEachBuiltinFrameAttr<ValueT>(
+      [&](const std::string& k, const ValueT& v) { attr_map->Set(k, v); });
+  VisitEachBuiltinFrameAttr(
+      [&](const auto& cls) { attr_map->Set(cls.Name(), cls); });
+  return attr_map;
+}
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/code_module.h b/paddle/ap/include/code_module/code_module.h
new file mode 100644
index 00000000000000..1b5445a461b475
--- /dev/null
+++ b/paddle/ap/include/code_module/code_module.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/type.h"
+#include "paddle/ap/include/code_module/adt.h"
+#include "paddle/ap/include/code_module/arg_type.h"
+#include "paddle/ap/include/code_module/data_type.h"
+#include "paddle/ap/include/code_module/func_declare.h"
+#include "paddle/ap/include/code_module/source_code.h"
+
+namespace ap::code_module {
+
+struct CodeModuleImpl {
+  adt::List<FuncDeclare> func_declares;
+  SourceCode source_code;
+
+  bool operator==(const CodeModuleImpl& other) const {
+    return other.func_declares == this->func_declares &&
+           other.source_code == this->source_code;
+  }
+};
+ADT_DEFINE_RC(CodeModule, CodeModuleImpl);
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/code_module_method_class.h b/paddle/ap/include/code_module/code_module_method_class.h
new file mode 100644
index 00000000000000..6427d3369b3234
--- /dev/null
+++ b/paddle/ap/include/code_module/code_module_method_class.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/axpr/method_class.h"
+#include "paddle/ap/include/axpr/naive_class_ops.h"
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/code_module/code_module.h"
+#include "paddle/ap/include/code_module/func_declare.h"
+#include "paddle/ap/include/code_module/source_code.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> GetCodeModuleClass();
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/data_type.h b/paddle/ap/include/code_module/data_type.h
new file mode 100644
index 00000000000000..9d4eb578ec218f
--- /dev/null
+++ b/paddle/ap/include/code_module/data_type.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/pstring.h"
+
+namespace ap {
+
+using complex64 = ::phi::dtype::complex<float>;
+using complex128 = ::phi::dtype::complex<double>;
+using float16 = ::phi::dtype::float16;
+using bfloat16 = ::phi::dtype::bfloat16;
+using float8_e4m3fn = ::phi::dtype::float8_e4m3fn;
+using float8_e5m2 = ::phi::dtype::float8_e5m2;
+using pstring = ::phi::dtype::pstring;
+
+#define AP_FOR_EACH_INT_TYPE(_) \
+  _(int8)                       \
+  _(uint8)                      \
+  _(int16)                      \
+  _(uint16)                     \
+  _(int32)                      \
+  _(uint32)                     \
+  _(int64)                      \
+  _(uint64)
+
+}  // namespace ap
diff --git a/paddle/ap/include/code_module/directory.h b/paddle/ap/include/code_module/directory.h
new file mode 100644
index 00000000000000..254a7243050a0a
--- /dev/null
+++ b/paddle/ap/include/code_module/directory.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/attr_map.h"
+#include "paddle/ap/include/axpr/type.h"
+
+namespace ap::code_module {
+
+template <typename ValueT>
+struct Directory {
+  axpr::AttrMap<ValueT> dentry2file;
+
+  bool operator==(const Directory& other) const {
+    return this->dentry2file == other.dentry2file;
+  }
+};
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/directory_method_class.h b/paddle/ap/include/code_module/directory_method_class.h
new file mode 100644
index 00000000000000..c0d318d15be63f
--- /dev/null
+++ b/paddle/ap/include/code_module/directory_method_class.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/axpr/method_class.h"
+#include "paddle/ap/include/axpr/naive_class_ops.h"
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/code_module/directory.h"
+#include "paddle/ap/include/code_module/file.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> GetDirectoryClass();
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/file.h b/paddle/ap/include/code_module/file.h
new file mode 100644
index 00000000000000..da889d8a72a217
--- /dev/null
+++ b/paddle/ap/include/code_module/file.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/type.h"
+#include "paddle/ap/include/code_module/directory.h"
+#include "paddle/ap/include/code_module/file_content.h"
+#include "paddle/ap/include/code_module/soft_link.h"
+
+namespace ap::code_module {
+
+template <typename FileT>
+using FileImpl = std::variant<FileContent, SoftLink, Directory<FileT>>;
+
+struct File : public FileImpl<File> {
+  using FileImpl<File>::FileImpl;
+  ADT_DEFINE_VARIANT_METHODS(FileImpl<File>);
+
+  static adt::Result<File> CastFromAxprValue(const axpr::Value& val) {
+    if (val.template CastableTo<FileContent>()) {
+      ADT_LET_CONST_REF(file_content, val.template CastTo<FileContent>());
+      return file_content;
+    }
+    if (val.template CastableTo<SoftLink>()) {
+      ADT_LET_CONST_REF(soft_link, val.template CastTo<SoftLink>());
+      return soft_link;
+    }
+    if (val.template CastableTo<Directory<File>>()) {
+      ADT_LET_CONST_REF(directory, val.template CastTo<Directory<File>>());
+      return directory;
+    }
+    return adt::errors::TypeError{"File::CastFromAxprValue() failed."};
+  }
+};
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/file_content.h b/paddle/ap/include/code_module/file_content.h
new file mode 100644
index 00000000000000..33e2e4f240f674
--- /dev/null
+++ b/paddle/ap/include/code_module/file_content.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/type.h"
+
+namespace ap::code_module {
+
+struct FileContentImpl {
+  std::string file_content;
+
+  bool operator==(const FileContentImpl& other) const {
+    return this->file_content == other.file_content;
+  }
+};
+ADT_DEFINE_RC(FileContent, FileContentImpl);
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/file_content_method_class.h b/paddle/ap/include/code_module/file_content_method_class.h
new file mode 100644
index 00000000000000..654eacc7c5028a
--- /dev/null
+++ b/paddle/ap/include/code_module/file_content_method_class.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/axpr/method_class.h"
+#include "paddle/ap/include/axpr/naive_class_ops.h"
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/code_module/file_content.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> GetFileContentClass();
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/func_declare.h b/paddle/ap/include/code_module/func_declare.h
new file mode 100644
index 00000000000000..40d52b5bc56108
--- /dev/null
+++ b/paddle/ap/include/code_module/func_declare.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/type.h"
+#include "paddle/ap/include/code_module/adt.h"
+#include "paddle/ap/include/code_module/arg_type.h"
+#include "paddle/ap/include/code_module/data_type.h"
+
+namespace ap::code_module {
+
+using FuncId = std::string;
+
+struct FuncDeclareImpl {
+  ArgType ret_type;
+  FuncId func_id;
+  adt::List<ArgType> arg_types;
+
+  bool operator==(const FuncDeclareImpl& other) const {
+    return other.func_id == this->func_id && other.ret_type == this->ret_type &&
+           other.arg_types == this->arg_types;
+  }
+};
+ADT_DEFINE_RC(FuncDeclare, FuncDeclareImpl);
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/func_declare_method_class.h b/paddle/ap/include/code_module/func_declare_method_class.h
new file mode 100644
index 00000000000000..abfcd1debda746
--- /dev/null
+++ b/paddle/ap/include/code_module/func_declare_method_class.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/axpr/method_class.h"
+#include "paddle/ap/include/axpr/naive_class_ops.h"
+#include "paddle/ap/include/axpr/value.h"
+#include "paddle/ap/include/code_module/func_declare.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> GetFuncDeclareClass();
+
+}  // namespace ap::code_module
diff --git a/paddle/ap/include/code_module/module_compile_helper.h b/paddle/ap/include/code_module/module_compile_helper.h
new file mode 100644
index 00000000000000..ba03d4a7655bb2
--- /dev/null
+++ b/paddle/ap/include/code_module/module_compile_helper.h
@@ -0,0 +1,86 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/code_module/api_wrapper_project_maker.h"
#include "paddle/ap/include/code_module/code_module.h"
#include "paddle/ap/include/code_module/package.h"
#include "paddle/ap/include/code_module/project_compile_helper.h"

namespace ap::code_module {

// NOTE(review): template arguments inside angle brackets were stripped when
// this patch was pasted (e.g. `adt::Result<...>`, `Has<...>()`,
// `TryGet<...>()`, `axpr::AttrMap<...>`); restore them from the original
// commit before compiling.
//
// Compiles a Project-backed CodeModule into a Package-backed one:
// generates an "api wrapper" project for the module's declared functions,
// compiles both projects to shared objects under
// `<workspace_dir>/<relative_dir_in_workspace>/`, and returns a CodeModule
// whose source is the resulting Package.
class ModuleCompileHelper {
  std::string workspace_dir_;
  std::string relative_dir_in_workspace_;

 public:
  ModuleCompileHelper(const std::string& workspace_dir,
                      const std::string& relative_dir_in_workspace)
      : workspace_dir_(workspace_dir),
        relative_dir_in_workspace_(relative_dir_in_workspace) {}

  // Precondition: project_module's source_code holds a Project (checked).
  // Side effects: writes project files to the filesystem and runs the
  // projects' compile commands via ProjectCompileHelper.
  adt::Result CompileProjectModuleToPackageModule(
      const CodeModule& project_module) const {
    ADT_CHECK(project_module->source_code.template Has());
    const auto& func_declares = project_module->func_declares.vector();
    ADT_LET_CONST_REF(
        api_wrapper_project,
        code_module::ApiWrapperProjectMaker{}.Make(func_declares));
    ADT_LET_CONST_REF(main_project, GetMainProject(project_module));
    code_module::ProjectCompileHelper api_wrapper_compile_helper(
        GetApiWrapperProjectAbsoluteDir(), api_wrapper_project);
    code_module::ProjectCompileHelper main_compile_helper(
        GetMainProjectAbsoluteDir(), main_project);
    // Dump both projects to disk first, then compile both.
    ADT_RETURN_IF_ERR(api_wrapper_compile_helper.DumpNestedFilesToFs());
    ADT_RETURN_IF_ERR(main_compile_helper.DumpNestedFilesToFs());
    ADT_RETURN_IF_ERR(api_wrapper_compile_helper.Compile());
    ADT_RETURN_IF_ERR(main_compile_helper.Compile());
    // The .so paths recorded in the Package are relative to workspace_dir_.
    // NOTE(review): GetApiWrapperProjectRelativeDir() already ends with '/',
    // so this produces a "//" in the path — harmless on POSIX, but confirm.
    std::string api_wrapper_so_relative_path =
        GetApiWrapperProjectRelativeDir() + "/" +
        api_wrapper_project->so_relative_path;
    std::string main_so_relative_path =
        GetMainProjectRelativeDir() + "/" + main_project->so_relative_path;
    // The package deliberately carries no nested files (empty Directory);
    // only the compiled .so locations and an empty attr map.
    Package ret_package{
        /*nested_files=*/Directory{},
        /*api_wrapper_so_relative_path=*/api_wrapper_so_relative_path,
        /*main_so_relative_path=*/main_so_relative_path,
        /*others=*/axpr::AttrMap{}};
    // Func declares are carried over unchanged from the project module.
    return CodeModule{project_module->func_declares, ret_package};
  }

 private:
  // Extracts the Project variant from the module's source_code
  // (error Result if it holds a Package instead).
  adt::Result GetMainProject(
      const code_module::CodeModule& code_module) const {
    return code_module->source_code.template TryGet();
  }

  std::string GetApiWrapperProjectAbsoluteDir() const {
    return workspace_dir_ + "/" + GetApiWrapperProjectRelativeDir();
  }

  std::string GetMainProjectAbsoluteDir() const {
    return workspace_dir_ + "/" + GetMainProjectRelativeDir();
  }

  std::string GetApiWrapperProjectRelativeDir() const {
    return relative_dir_in_workspace_ + "/api_wrapper/";
  }
  std::string GetMainProjectRelativeDir() const {
    return relative_dir_in_workspace_ + "/main/";
  }
};

}  // namespace ap::code_module
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h"
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/code_module/code_module.h"

namespace ap::code_module {

// NOTE(review): template arguments inside angle brackets were stripped when
// this patch was pasted (e.g. `adt::Result<AnfExpr>`, `std::vector<AnfExpr>`,
// `std::map<std::string, AnfExpr>`, `tVar<...>`, `AttrMap<...>`); restore
// them from the original commit before compiling.
//
// Serializes a CodeModule back into an axpr ANF expression that, when
// evaluated, reconstructs the module: `CodeModule(<func_declares>, <source>)`.
struct ModuleToAxprHelper {
  using AnfExpr = axpr::AnfExpr;

  // Emits the expression into an existing let-context.
  adt::Result ConvertModuleToAnfExpr(axpr::LetContext* ctx,
                                     const CodeModule& m) const {
    return ConvertModuleToAnfExprImpl(ctx, m);
  }

  // Wraps the conversion in a zero-argument lambda expression.
  adt::Result ConvertModuleToAnfExpr(const CodeModule& m) const {
    auto ConstructLambdaBody = [&](auto& ctx) -> adt::Result {
      return ConvertModuleToAnfExprImpl(&ctx, m);
    };
    return ap::axpr::LambdaExprBuilder{}.TryLambda({}, ConstructLambdaBody);
  }

 private:
  adt::Result ConvertModuleToAnfExprImpl(axpr::LetContext* ctx,
                                         const CodeModule& m) const {
    // ArgType -> `DataType.<name>` or `PointerType.<name>` attribute access.
    auto ConvertArgType = [&](auto& ctx, const auto& arg_type) -> AnfExpr {
      return arg_type.Match(
          [&](const ap::axpr::DataType& data_type) -> AnfExpr {
            const auto& var = ctx->Var("DataType").Attr(data_type.Name());
            return ap::axpr::tVar{var.name()};
          },
          [&](const ap::axpr::PointerType& pointer_type) -> AnfExpr {
            const auto& var = ctx->Var("PointerType").Attr(pointer_type.Name());
            return ap::axpr::tVar{var.name()};
          });
    };
    // FuncDeclare -> `FuncDeclare(<ret_type>, "<func_id>", [<arg_types>])`.
    auto ConvertFuncDeclareCall = [&](auto& ctx,
                                      const auto& func_declare) -> AnfExpr {
      const auto& ret_val_anf_expr =
          ConvertArgType(ctx, func_declare->ret_type);
      const auto& func_name = ctx->String(func_declare->func_id);
      std::vector elts;
      elts.reserve(func_declare->arg_types->size());
      for (const auto& arg_type : *func_declare->arg_types) {
        elts.emplace_back(ConvertArgType(ctx, arg_type));
      }
      const auto& arg_type_anf_expr = ctx->Call(ap::axpr::kBuiltinList(), elts);
      return ctx->Call(
          "FuncDeclare", ret_val_anf_expr, func_name, arg_type_anf_expr);
    };
    // All declares -> builtin list literal.
    auto ConvertFuncDeclareList = [&](auto& ctx) -> AnfExpr {
      std::vector elts;
      elts.reserve(m->func_declares->size());
      for (const auto& func_declare : *m->func_declares) {
        elts.emplace_back(ConvertFuncDeclareCall(ctx, func_declare));
      }
      return ctx->Call(ap::axpr::kBuiltinList(), elts);
    };
    // source_code variant -> Project(...) or Package(...) construction call.
    auto ConvertSourceCodeConstruction =
        [&](auto* ctx) -> adt::Result {
      return m->source_code.Match(
          [&](const ap::code_module::Project& project) -> adt::Result {
            return ConvertProjectConstruct(ctx, project);
          },
          [&](const ap::code_module::Package& package) -> adt::Result {
            return ConvertPackageConstruct(ctx, package);
          });
    };
    const auto& declare = ConvertFuncDeclareList(ctx);
    ADT_LET_CONST_REF(source_code, ConvertSourceCodeConstruction(ctx));
    return ctx->Call("CodeModule", declare, source_code);
  }

  // Project -> `Project(nested_files=..., compile_cmd=...,
  // so_relative_path=..., others=...)` keyword-argument apply.
  adt::Result ConvertProjectConstruct(
      ap::axpr::LetContext* ctx,
      const ap::code_module::Project& project) const {
    const auto& attrs = project->others;
    ADT_LET_CONST_REF(others_anf_expr,
                      GetCodeFromBuiltinSerializableAttrMap(ctx, attrs));
    std::map kwargs{
        {"nested_files", ConvertProjectNestedFiles(ctx, project->nested_files)},
        {"compile_cmd", AnfExpr{ctx->String(project->compile_cmd)}},
        {"so_relative_path", AnfExpr{ctx->String(project->so_relative_path)}},
        {"others", others_anf_expr},
    };
    return ctx->Apply("Project", {}, kwargs);
  }

  // Package -> `Package(nested_files=..., api_wrapper_so_relative_path=...,
  // main_so_relative_path=..., others=...)` keyword-argument apply.
  adt::Result ConvertPackageConstruct(
      ap::axpr::LetContext* ctx,
      const ap::code_module::Package& package) const {
    const auto& attrs = package->others;
    ADT_LET_CONST_REF(others_anf_expr,
                      GetCodeFromBuiltinSerializableAttrMap(ctx, attrs));
    const auto& api_so_path = package->api_wrapper_so_relative_path;
    const auto& main_so_path = package->main_so_relative_path;
    std::map kwargs{
        {"nested_files", ConvertProjectNestedFiles(ctx, package->nested_files)},
        {"api_wrapper_so_relative_path", AnfExpr{ctx->String(api_so_path)}},
        {"main_so_relative_path", AnfExpr{ctx->String(main_so_path)}},
        {"others", others_anf_expr},
    };
    return ctx->Apply("Package", {}, kwargs);
  }

  // Recursively converts the file tree: FileContent and SoftLink become
  // `Project.FileContent(str)` / `Project.SoftLink(str)` calls; a Directory
  // becomes `Project.Directory([name, sub], ...)` over its entries.
  AnfExpr ConvertProjectNestedFiles(ap::axpr::LetContext* ctx,
                                    const ap::code_module::File& file) const {
    return file.Match(
        [&](const ap::code_module::FileContent& file_content) -> AnfExpr {
          const auto& str = file_content->file_content;
          return ctx->Var("Project").Attr("FileContent").Call(ctx->String(str));
        },
        [&](const ap::code_module::SoftLink& soft_link) -> AnfExpr {
          const auto& str = soft_link->target_relative_path;
          return ctx->Var("Project").Attr("SoftLink").Call(ctx->String(str));
        },
        [&](const ap::code_module::Directory& dir)
            -> AnfExpr {
          std::vector args;
          for (const auto& [k, v] : dir.dentry2file->storage) {
            const auto& v_anf_expr = ConvertProjectNestedFiles(ctx, v);
            args.emplace_back(ctx->Call(
                ap::axpr::kBuiltinList(), ctx->String(k), v_anf_expr));
          }
          return ctx->Apply(ctx->Var("Project").Attr("Directory"), args);
        });
  }

  // Delegates serialization of the `others` attr map to the shared helper.
  adt::Result GetCodeFromBuiltinSerializableAttrMap(
      ap::axpr::LetContext* ctx,
      const ap::axpr::AttrMap& attr_map) const {
    return axpr::BuiltinSerializableAttrMapToAxprHelper{}.Convert(ctx,
                                                                  attr_map);
  }
};

}  // namespace ap::code_module
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/attr_map.h"
#include "paddle/ap/include/axpr/serializable_value.h"
#include "paddle/ap/include/axpr/type.h"
#include "paddle/ap/include/code_module/adt.h"
#include "paddle/ap/include/code_module/arg_type.h"
#include "paddle/ap/include/code_module/data_type.h"
#include "paddle/ap/include/code_module/file.h"

namespace ap::code_module {

// A compiled code module: a file tree plus the workspace-relative paths of
// the two compiled shared objects (the api-wrapper .so and the main .so),
// and an open-ended attr map for extra metadata.
// NOTE(review): the template argument of `axpr::AttrMap` was stripped when
// this patch was pasted; restore it from the original commit.
struct PackageImpl {
  Directory nested_files;
  std::string api_wrapper_so_relative_path;
  std::string main_so_relative_path;
  axpr::AttrMap others;

  // Identity equality: two packages compare equal only when they are the
  // same object (no structural comparison is attempted).
  bool operator==(const PackageImpl& other) const { return this == &other; }
};
// Reference-counted handle `Package` over PackageImpl.
ADT_DEFINE_RC(Package, PackageImpl);

}  // namespace ap::code_module
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/code_module/file.h" +#include "paddle/ap/include/code_module/package.h" + +namespace ap::code_module { + +axpr::TypeImpl> GetPackageClass(); + +} // namespace ap::code_module diff --git a/paddle/ap/include/code_module/project.h b/paddle/ap/include/code_module/project.h new file mode 100644 index 00000000000000..42240156bed4d6 --- /dev/null +++ b/paddle/ap/include/code_module/project.h @@ -0,0 +1,37 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#pragma once
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/attr_map.h"
#include "paddle/ap/include/axpr/serializable_value.h"
#include "paddle/ap/include/axpr/type.h"
#include "paddle/ap/include/code_module/adt.h"
#include "paddle/ap/include/code_module/arg_type.h"
#include "paddle/ap/include/code_module/data_type.h"
#include "paddle/ap/include/code_module/file.h"

namespace ap::code_module {

// An uncompiled code module: a file tree to dump to disk, the shell command
// that compiles it, the path (relative to the project dir) where the
// resulting .so lands, and an open-ended attr map for extra metadata.
// NOTE(review): the template argument of `axpr::AttrMap` was stripped when
// this patch was pasted; restore it from the original commit.
struct ProjectImpl {
  Directory nested_files;
  std::string compile_cmd;
  std::string so_relative_path;
  axpr::AttrMap others;

  // Identity equality: two projects compare equal only when they are the
  // same object (no structural comparison is attempted).
  bool operator==(const ProjectImpl& other) const { return this == &other; }
};
// Reference-counted handle `Project` over ProjectImpl.
ADT_DEFINE_RC(Project, ProjectImpl);

}  // namespace ap::code_module
#pragma once

// NOTE(review): the two '#include' targets below were lost when this patch
// was pasted (likely <fstream> and <optional>/<cstdlib>); restore them from
// the original commit.
#include
#include
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/code_module/project.h"
#include "paddle/ap/include/env/ap_path.h"

namespace ap::code_module {

// Dumps a Project's file tree under `workspace_dir` and runs its compile
// command there via std::system().
//
// NOTE(review): all shell commands here are built by string concatenation
// ("cd ...", "mkdir -p ...", "ln -s ...") — paths or compile_cmd containing
// spaces or shell metacharacters will break or be interpreted by the shell.
// Confirm callers only pass trusted, shell-safe strings.
// NOTE(review): std::system() returns -1 when the shell itself cannot be
// spawned; applying WEXITSTATUS to -1 yields an unspecified value, so that
// failure mode is not reliably detected.
struct ProjectCompileHelper {
  ProjectCompileHelper(const std::string& workspace_dir_val,
                       const Project& project_val)
      : workspace_dir(workspace_dir_val), project(project_val) {}

  // Writes the project's whole file tree under workspace_dir.
  adt::Result DumpNestedFilesToFs() {
    return DumpNestedFilesToFs(this->project->nested_files, "");
  }

  // Runs `cd <workspace_dir>; <compile_cmd>` and fails unless the shell
  // exits with status 0.
  adt::Result Compile() {
    int ret_code = 0;
    std::string change_dir_cmd = std::string() + "cd " + this->workspace_dir;
    std::string compile_cmd =
        change_dir_cmd + "; " + this->project->compile_cmd;
    ret_code = WEXITSTATUS(std::system(compile_cmd.c_str()));
    ADT_CHECK(ret_code == 0) << adt::errors::RuntimeError{
        std::string() + "system() failed. ret_code: " +
        std::to_string(ret_code) + ", compile_cmd: " + compile_cmd};
    return adt::Ok{};
  }

  // Absolute path of the compiled shared object.
  std::string GetSoPath() {
    return this->workspace_dir + "/" + this->project->so_relative_path;
  }

 private:
  std::string workspace_dir;
  Project project;

  // Recursively materializes `directory` at
  // <workspace_dir>/<relative_dir_path>, creating directories with
  // `mkdir -p` and dispatching on each entry's file kind.
  adt::Result DumpNestedFilesToFs(
      const Directory& directory, const std::string& relative_dir_path) {
    std::string dir_path = this->workspace_dir + "/" + relative_dir_path;
    std::string cmd = std::string() + "mkdir -p " + dir_path;
    ADT_CHECK(WEXITSTATUS(std::system(cmd.c_str())) == 0);
    using Ok = adt::Result;
    for (const auto& [dentry, file] : directory.dentry2file->storage) {
      ADT_RETURN_IF_ERR(file.Match(
          [&](const FileContent& file_content) -> Ok {
            return DumpFileContentToFs(file_content,
                                       relative_dir_path + "/" + dentry);
          },
          [&](const SoftLink& soft_link) -> Ok {
            return DumpSoftLinkToFs(soft_link,
                                    relative_dir_path + "/" + dentry);
          },
          [&](const Directory& sub_dir) -> Ok {
            return DumpNestedFilesToFs(sub_dir,
                                       relative_dir_path + "/" + dentry);
          }));
    }
    return adt::Ok{};
  }

  // Writes one file's content; fails if the file cannot be opened.
  adt::Result DumpFileContentToFs(
      const FileContent& file_content, const std::string& relative_file_path) {
    std::string file_path = this->workspace_dir + "/" + relative_file_path;
    std::ofstream of{file_path};
    ADT_CHECK(of.is_open()) << adt::errors::RuntimeError{
        std::string() + "file open failed. file_path: " + file_path};
    of << file_content->file_content;
    of.close();
    return adt::Ok{};
  }

  // Resolves the link's target against each configured AP path (first hit
  // wins) and creates the symlink with `ln -s`.
  // NOTE(review): `ln -s` without -f fails if the link already exists, e.g.
  // when dumping the same project twice into one workspace — confirm intended.
  adt::Result DumpSoftLinkToFs(const SoftLink& soft_link,
                               const std::string& relative_link_path) {
    std::string link = this->workspace_dir + "/" + relative_link_path;
    std::optional target_path;
    auto FindExistedSourcePath =
        [&](const auto& prefix) -> adt::Result {
      std::string cur_target_path =
          std::string() + prefix + "/" + soft_link->target_relative_path;
      if (FileExists(cur_target_path)) {
        target_path = cur_target_path;
        return adt::Break{};
      } else {
        return adt::Continue{};
      }
    };
    ADT_RETURN_IF_ERR(env::VisitEachApPath(FindExistedSourcePath));
    ADT_CHECK(target_path.has_value()) << adt::errors::RuntimeError{
        std::string() +
        "link failed. relative_path: " + soft_link->target_relative_path};
    std::string cmd =
        std::string() + "ln -s " + target_path.value() + " " + link;
    ADT_CHECK(WEXITSTATUS(std::system(cmd.c_str())) == 0);
    return adt::Ok{};
  }

  // True iff `filepath` can be opened for reading.
  bool FileExists(const std::string& filepath) {
    std::fstream fp;
    fp.open(filepath, std::fstream::in);
    if (fp.is_open()) {
      fp.close();
      return true;
    } else {
      return false;
    }
  }
};

}  // namespace ap::code_module
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/code_module/directory.h" +#include "paddle/ap/include/code_module/directory_method_class.h" +#include "paddle/ap/include/code_module/file.h" +#include "paddle/ap/include/code_module/file_content_method_class.h" +#include "paddle/ap/include/code_module/project.h" +#include "paddle/ap/include/code_module/soft_link_method_class.h" + +namespace ap::code_module { + +axpr::TypeImpl> GetProjectClass(); + +} // namespace ap::code_module diff --git a/paddle/ap/include/code_module/rt_module.h b/paddle/ap/include/code_module/rt_module.h new file mode 100644 index 00000000000000..0f31cb8e0693c9 --- /dev/null +++ b/paddle/ap/include/code_module/rt_module.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#pragma once

namespace ap::code_module {

// Abstract base for runtime-loaded modules. Instances are only created via
// derived classes (constructor is protected); the virtual destructor makes
// deletion through an RtModule* safe.
class RtModule {
 public:
  virtual ~RtModule() = default;

 protected:
  RtModule() = default;
};

}  // namespace ap::code_module

// --- paddle/ap/include/code_module/soft_link.h ---

#pragma once
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/type.h"

namespace ap::code_module {

// A symbolic-link entry in a project's file tree; the target path is
// relative and is resolved against the configured AP paths when the tree is
// dumped to disk (see ProjectCompileHelper::DumpSoftLinkToFs).
struct SoftLinkImpl {
  std::string target_relative_path;

  // Structural equality on the target path.
  bool operator==(const SoftLinkImpl& other) const {
    return this->target_relative_path == other.target_relative_path;
  }
};
// Reference-counted handle `SoftLink` over SoftLinkImpl.
ADT_DEFINE_RC(SoftLink, SoftLinkImpl);

}  // namespace ap::code_module
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/code_module/soft_link.h" + +namespace ap::code_module { + +axpr::TypeImpl> GetSoftLinkClass(); + +} // namespace ap::code_module diff --git a/paddle/ap/include/code_module/source_code.h b/paddle/ap/include/code_module/source_code.h new file mode 100644 index 00000000000000..8a54a1902615b3 --- /dev/null +++ b/paddle/ap/include/code_module/source_code.h @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_module/package.h" +#include "paddle/ap/include/code_module/project.h" + +namespace ap::code_module { + +using SourceCodeImpl = std::variant; + +struct SourceCode : public SourceCodeImpl { + using SourceCodeImpl::SourceCodeImpl; + ADT_DEFINE_VARIANT_METHODS(SourceCodeImpl); + + static adt::Result CastFromAxprValue(const axpr::Value& val) { + if (val.template CastableTo()) { + ADT_LET_CONST_REF(project, val.template CastTo()); + return project; + } + if (val.template CastableTo()) { + ADT_LET_CONST_REF(package, val.template CastTo()); + return package; + } + return adt::errors::TypeError{"SourceCode::CastFromAxprValue() failed"}; + } +}; + +} // namespace ap::code_module diff --git a/paddle/ap/include/code_module/value.h b/paddle/ap/include/code_module/value.h new file mode 100644 index 00000000000000..317bad7e83356c --- /dev/null +++ b/paddle/ap/include/code_module/value.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/pointer_type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/code_module/adt.h" +#include "paddle/ap/include/code_module/code_module.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/code_module/func_declare.h" + +namespace ap::code_module { + +using axpr::Value; + +} // namespace ap::code_module diff --git a/paddle/ap/include/code_module/value_method_class.h b/paddle/ap/include/code_module/value_method_class.h new file mode 100644 index 00000000000000..b8d50f712654b1 --- /dev/null +++ b/paddle/ap/include/code_module/value_method_class.h @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/data_type_method_class.h" +#include "paddle/ap/include/axpr/pointer_type_method_class.h" +#include "paddle/ap/include/axpr/value_method_class.h" diff --git a/paddle/ap/include/common/unique_id.h b/paddle/ap/include/common/unique_id.h new file mode 100644 index 00000000000000..294ff0f04c5ec7 --- /dev/null +++ b/paddle/ap/include/common/unique_id.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// NOTE(review): the two '#include' targets were lost when this patch was
// pasted; <atomic> and <string> are restored from the code below.
#include <atomic>
#include <string>

namespace ap::common {

// Returns `prefix` followed by a process-wide monotonically increasing
// sequence number. Thread-safe: the counter is a single std::atomic<int>,
// so concurrent callers always receive distinct ids (the element type is
// spelled out instead of relying on C++17 CTAD of `std::atomic seq_no(0)`).
inline std::string NewUniqueId(const std::string& prefix) {
  static std::atomic<int> seq_no(0);
  return prefix + std::to_string(seq_no++);
}

}  // namespace ap::common
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/attr_map.h"
#include "paddle/ap/include/axpr/builtin_frame_util.h"
#include "paddle/ap/include/drr/drr_ctx_method_class.h"
#include "paddle/ap/include/drr/drr_value.h"

namespace ap::drr {

// NOTE(review): the template parameter lists and several template arguments
// (e.g. `template <typename DoEachT>`, `drr::Type<...>`,
// `axpr::AttrMap<...>`) were stripped when this patch was pasted; restore
// them from the original commit before compiling.

// Invokes DoEach on every drr class exposed to the builtin frame.
// Currently that is only the DrrCtx class object.
template
void VisitEachBuiltinFrameClass(const DoEachT& DoEach) {
  DoEach(drr::Type{}.GetClass());
}

// Builds the attr map used as the drr interpreter's builtin frame by
// merging, in order: the generic axpr builtin attrs, the drr builtin frame
// classes, and any extra classes supplied by `Visitor` (which is called
// with a registration callback). Later registrations overwrite earlier
// ones for duplicate names.
template
ap::axpr::AttrMap MakeBuiltinFrameAttrMap(
    const VisitorT& Visitor) {
  ap::axpr::AttrMap attr_map;
  ap::axpr::VisitEachBuiltinFrameAttr(
      [&](const std::string& k, const axpr::Value& v) { attr_map->Set(k, v); });
  VisitEachBuiltinFrameClass(
      [&](const auto& cls) { attr_map->Set(cls.Name(), cls); });
  Visitor([&](const auto& cls) { attr_map->Set(cls.Name(), cls); });
  return attr_map;
}

}  // namespace ap::drr
#pragma once
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/function.h"
#include "paddle/ap/include/axpr/serializable_value.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/drr/drr_pass_type.h"
#include "paddle/ap/include/drr/result_pattern_ctx.h"
#include "paddle/ap/include/drr/source_pattern_ctx.h"
#include "paddle/ap/include/drr/tags.h"
#include "paddle/ap/include/drr/type.h"
#include "paddle/ap/include/memory/circlable_ref_list_base.h"

namespace ap::drr {

// NOTE(review): template arguments inside angle brackets were stripped when
// this patch was pasted (e.g. `std::weak_ptr<...>`, `std::optional<...>`,
// `adt::Result<...>`, `Type<DrrCtx>`, `axpr::TypeImpl<...>`); restore them
// from the original commit before compiling.
//
// Context of one drr (declarative rewrite rule) pass: the source pattern to
// match, the result pattern to rewrite to, an optional constraint function,
// plus pass name/type. All pattern fields are optional because the ctx is
// populated incrementally while the rule definition is interpreted.
struct DrrCtxImpl {
  // Non-owning back-reference to the interpreter's ref list (weak to avoid
  // a reference cycle with the objects this ctx owns).
  std::weak_ptr circlable_ref_list;
  std::optional pass_name;
  std::optional source_pattern_ctx;
  std::optional result_pattern_ctx;
  std::optional constraint_func;
  std::optional drr_pass_type;

  // Fails (ADT_CHECK) if the source pattern has not been set yet.
  adt::Result GetSourcePatternCtx() const {
    ADT_CHECK(this->source_pattern_ctx.has_value());
    return this->source_pattern_ctx.value();
  }

  // Fails (ADT_CHECK) if the result pattern has not been set yet.
  adt::Result GetResultPatternCtx() const {
    ADT_CHECK(this->result_pattern_ctx.has_value());
    return this->result_pattern_ctx.value();
  }

  // Identity equality: two ctxs are equal only when they are the same object.
  bool operator==(const DrrCtxImpl& other) const { return this == &other; }
};

// Reference-counted handle `DrrCtx` over DrrCtxImpl.
ADT_DEFINE_RC(DrrCtx, DrrCtxImpl);

// Class object exposing DrrCtx to the axpr interpreter
// (defined in drr_ctx_method_class).
axpr::TypeImpl> GetDrrCtxClass();

// Type tag specialization wiring DrrCtx into the drr Type<> machinery.
template <>
struct Type : public std::monostate {
  using std::monostate::monostate;
  const char* Name() const { return "DrrCtx"; }

  static axpr::TypeImpl> GetClass() {
    return GetDrrCtxClass();
  }
};

}  // namespace ap::drr
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/builtin_high_order_func_type.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_ctx.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/ir_op.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" + +namespace ap::drr { + +axpr::TypeImpl> GetDrrCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/drr_graph_descriptor.h b/paddle/ap/include/drr/drr_graph_descriptor.h new file mode 100644 index 00000000000000..1abd874bdc827f --- /dev/null +++ b/paddle/ap/include/drr/drr_graph_descriptor.h @@ -0,0 +1,442 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/topo_kind.h" +#include "paddle/ap/include/graph/graph_descriptor.h" +#include "paddle/ap/include/graph/node.h" + +namespace ap::drr { + +struct DefaultDrrGraphDescriptor { + using DrrNode = drr::Node; + using DrrGraphNode = graph::Node; + using NodeT = DrrGraphNode; + + using DrrNativeIrValue = ap::drr::NativeIrValue; + using DrrPackedIrValue = ap::drr::PackedIrValue; + using DrrNativeIrOp = ap::drr::NativeIrOp; + using DrrPackedIrOp = ap::drr::PackedIrOp; + using DrrOptPackedIrOp = ap::drr::OptPackedIrOp; + using DrrNativeIrOpOperand = ap::drr::NativeIrOpOperand; + using DrrPackedIrOpOperand = ap::drr::PackedIrOpOperand; + using DrrOptPackedIrOpOperand = ap::drr::OptPackedIrOpOperand; + using DrrNativeIrOpResult = ap::drr::NativeIrOpResult; + using DrrPackedIrOpResult = ap::drr::PackedIrOpResult; + using DrrOptPackedIrOpResult = ap::drr::OptPackedIrOpResult; + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + ADT_LET_CONST_REF(upstreams, node.UpstreamNodes()); + return upstreams.VisitNodes(DoEach); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + ADT_LET_CONST_REF(downstreams, node.DownstreamNodes()); + return downstreams.VisitNodes(DoEach); + } + + template + adt::Result CastSoleUnignoredInput(const DrrNode& node) const { + std::optional opt_sole_input{}; + auto DoEachUpstream = + [&](const DrrGraphNode& upstream) -> adt::Result { + ADT_LET_CONST_REF(ignored, IgnoredNode(upstream)); + if (ignored) { + return adt::Ok{}; + } + ADT_LET_CONST_REF(drr_upstream, upstream.Get()); + ADT_LET_CONST_REF(casted, drr_upstream.template TryGet()); + ADT_CHECK(!opt_sole_input.has_value()); + opt_sole_input = casted; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstreamNodes(node.node(), DoEachUpstream)); + 
ADT_CHECK(opt_sole_input.has_value()); + return opt_sole_input.value(); + } + + adt::Result GetSoleInput(const DrrNode& node) const { + std::optional opt_sole_input{}; + auto DoEachUpstream = + [&](const DrrGraphNode& upstream) -> adt::Result { + ADT_LET_CONST_REF(drr_upstream, upstream.Get()); + ADT_CHECK(!opt_sole_input.has_value()); + opt_sole_input = drr_upstream; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstreamNodes(node.node(), DoEachUpstream)); + ADT_CHECK(opt_sole_input.has_value()); + return opt_sole_input.value(); + } + + template + adt::Result CastSoleUnignoredOutput(const DrrNode& node) const { + std::optional opt_sole_output{}; + auto DoEachDownstream = + [&](const DrrGraphNode& downstream) -> adt::Result { + ADT_LET_CONST_REF(ignored, IgnoredNode(downstream)); + if (ignored) { + return adt::Ok{}; + } + ADT_LET_CONST_REF(drr_downstream, downstream.Get()); + ADT_LET_CONST_REF(casted, drr_downstream.template TryGet()); + ADT_CHECK(!opt_sole_output.has_value()); + opt_sole_output = casted; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitDownstreamNodes(node.node(), DoEachDownstream)); + ADT_CHECK(opt_sole_output.has_value()); + return opt_sole_output.value(); + } + + adt::Result GetSoleOutput(const DrrNode& node) const { + std::optional opt_sole_output{}; + auto DoEachDownstream = + [&](const DrrGraphNode& downstream) -> adt::Result { + ADT_LET_CONST_REF(drr_downstream, downstream.Get()); + ADT_CHECK(!opt_sole_output.has_value()); + opt_sole_output = drr_downstream; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitDownstreamNodes(node.node(), DoEachDownstream)); + ADT_CHECK(opt_sole_output.has_value()); + return opt_sole_output.value(); + } + + adt::Result GetNumInputs(const DrrNode& node) const { + std::size_t num_inputs = 0; + auto DoEachUpstream = + [&](const DrrGraphNode& upstream) -> adt::Result { + ++num_inputs; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstreamNodes(node.node(), DoEachUpstream)); + return num_inputs; + } + + 
adt::Result GetNumOutputs(const DrrNode& node) const { + std::size_t num_outputs = 0; + auto DoEachDownstream = + [&](const DrrGraphNode& downstream) -> adt::Result { + ++num_outputs; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitDownstreamNodes(node.node(), DoEachDownstream)); + return num_outputs; + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + ADT_LET_CONST_REF(drr_node, node.Get()); + return graph::SmallGraphNodeTopoCstr{drr_node.node_topo_cstr()}; + } + + adt::Result IgnoredNode(const NodeT& node) const { + ADT_LET_CONST_REF(drr_node, node.Get()); + return drr_node.Match( + [](const DrrPackedIrValue&) -> adt::Result { return true; }, + [&](const DrrPackedIrOpOperand& impl) -> adt::Result { + ADT_LET_CONST_REF(upstreams, impl->node.UpstreamNodes()); + ADT_CHECK(upstreams.size(), 1); + ADT_LET_CONST_REF(upstream_node, upstreams.Sole()); + return IgnoredNode(upstream_node); + }, + [&](const DrrPackedIrOpResult& impl) -> adt::Result { + ADT_LET_CONST_REF(downstreams, impl->node.DownstreamNodes()); + ADT_CHECK(downstreams.size(), 1); + ADT_LET_CONST_REF(downstream_node, downstreams.Sole()); + return IgnoredNode(downstream_node); + }, + [](const DrrNativeIrValue&) -> adt::Result { return false; }, + [](const DrrNativeIrOp&) -> adt::Result { return false; }, + [](const DrrPackedIrOp&) -> adt::Result { return false; }, + [](const DrrOptPackedIrOp&) -> adt::Result { return false; }, + [](const DrrNativeIrOpOperand&) -> adt::Result { return false; }, + [&](const DrrOptPackedIrOpOperand& impl) -> adt::Result { + ADT_LET_CONST_REF(upstreams, impl->node.UpstreamNodes()); + ADT_CHECK(upstreams.size(), 1); + ADT_LET_CONST_REF(upstream_node, upstreams.Sole()); + return IgnoredNode(upstream_node); + }, + [](const DrrNativeIrOpResult&) -> adt::Result { return false; }, + [&](const DrrOptPackedIrOpResult& impl) -> adt::Result { + ADT_LET_CONST_REF(downstreams, impl->node.DownstreamNodes()); + ADT_CHECK(downstreams.size(), 1); + 
ADT_LET_CONST_REF(downstream_node, downstreams.Sole()); + return IgnoredNode(downstream_node); + }); + } + + adt::Result IsOpNode(const NodeT& node) const { + ADT_LET_CONST_REF(drr_node, node.Get()); + return drr_node.Match( + [](const DrrNativeIrOp&) -> bool { return true; }, + [](const DrrPackedIrOp&) -> bool { return true; }, + [](const DrrOptPackedIrOp&) -> bool { return true; }, + [](const DrrNativeIrValue&) -> bool { return false; }, + [](const DrrPackedIrValue&) -> bool { return false; }, + [](const DrrNativeIrOpOperand&) -> bool { return false; }, + [](const DrrPackedIrOpOperand&) -> bool { return false; }, + [](const DrrOptPackedIrOpOperand&) -> bool { return false; }, + [](const DrrNativeIrOpResult&) -> bool { return false; }, + [](const DrrPackedIrOpResult&) -> bool { return false; }, + [](const DrrOptPackedIrOpResult&) -> bool { return false; }); + } + + adt::Result IsValueNode(const NodeT& node) const { + ADT_LET_CONST_REF(drr_node, node.Get()); + return drr_node.Match( + [](const DrrNativeIrOp&) -> bool { return false; }, + [](const DrrPackedIrOp&) -> bool { return false; }, + [](const DrrOptPackedIrOp&) -> bool { return false; }, + [](const DrrNativeIrValue&) -> bool { return true; }, + [](const DrrPackedIrValue&) -> bool { return true; }, + [](const DrrNativeIrOpOperand&) -> bool { return false; }, + [](const DrrPackedIrOpOperand&) -> bool { return false; }, + [](const DrrOptPackedIrOpOperand&) -> bool { return false; }, + [](const DrrNativeIrOpResult&) -> bool { return false; }, + [](const DrrPackedIrOpResult&) -> bool { return false; }, + [](const DrrOptPackedIrOpResult&) -> bool { return false; }); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + ADT_LET_CONST_REF(drr_node, node.Get()); + const graph::BigGraphNodeTopoCstr& drr_node_topo_cstr{ + drr_node.node_topo_cstr()}; + return drr_node_topo_cstr.TopoSatisfy(node_topo_cstr); + } +}; + +struct 
AllOperandAndResultDrrGraphDescriptor { + using DrrNode = drr::Node; + using DrrGraphNode = graph::Node; + using NodeT = DrrGraphNode; + + DefaultDrrGraphDescriptor backend_graph; + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + auto DoEachOpOrValue = [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_op_node, backend_graph.IsOpNode(upstream)); + ADT_LET_CONST_REF(is_value_node, backend_graph.IsValueNode(upstream)); + ADT_CHECK(is_op_node || is_value_node); + return backend_graph.VisitUpstreamNodes(upstream, DoEach); + }; + return backend_graph.VisitUpstreamNodes(node, DoEachOpOrValue); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + auto DoEachOpOrValue = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_op_node, backend_graph.IsOpNode(downstream)); + ADT_LET_CONST_REF(is_value_node, backend_graph.IsValueNode(downstream)); + ADT_CHECK(is_op_node || is_value_node); + return backend_graph.VisitDownstreamNodes(downstream, DoEach); + }; + return backend_graph.VisitDownstreamNodes(node, DoEachOpOrValue); + } + + template + adt::Result CastSoleUnignoredInput(const DrrNode& node) const { + std::optional opt_sole_input{}; + auto DoEachUpstream = + [&](const DrrGraphNode& upstream) -> adt::Result { + ADT_LET_CONST_REF(ignored, IgnoredNode(upstream)); + if (ignored) { + return adt::Ok{}; + } + ADT_LET_CONST_REF(drr_upstream, upstream.Get()); + ADT_LET_CONST_REF(casted, drr_upstream.template TryGet()); + ADT_CHECK(!opt_sole_input.has_value()); + opt_sole_input = casted; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstreamNodes(node.node(), DoEachUpstream)); + ADT_CHECK(opt_sole_input.has_value()); + return opt_sole_input.value(); + } + + adt::Result GetNumInputs(const DrrNode& node) const { + std::size_t num_inputs = 0; + auto DoEachUpstream = + [&](const DrrGraphNode& upstream) -> adt::Result { + ++num_inputs; + return 
adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstreamNodes(node.node(), DoEachUpstream)); + return num_inputs; + } + + adt::Result GetNumOutputs(const DrrNode& node) const { + std::size_t num_outputs = 0; + auto DoEachDownstream = + [&](const DrrGraphNode& downstream) -> adt::Result { + ++num_outputs; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitDownstreamNodes(node.node(), DoEachDownstream)); + return num_outputs; + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return backend_graph.GetSmallGraphNodeTopoCstr(node); + } + + adt::Result IgnoredNode(const NodeT& node) const { + ADT_LET_CONST_REF(is_op_node, backend_graph.IsOpNode(node)); + ADT_LET_CONST_REF(is_value_node, backend_graph.IsValueNode(node)); + if (is_op_node || is_value_node) { + return true; + } + return backend_graph.IgnoredNode(node); + } + + adt::Result IsOpNode(const NodeT& node) const { + return backend_graph.IsOpNode(node); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + return backend_graph.TopoSatisfy(node, node_topo_cstr); + } +}; + +struct NativeOperandAndResultDrrGraphDescriptor { + using DrrNode = drr::Node; + using DrrGraphNode = graph::Node; + using NodeT = DrrGraphNode; + + AllOperandAndResultDrrGraphDescriptor backend_graph; + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + ADT_LET_CONST_REF(is_node_native, IsNative(node)); + ADT_CHECK(is_node_native); + auto VisitEachNative = [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_upstream_native, IsNative(upstream)); + ADT_CHECK(!is_upstream_native); + return backend_graph.VisitUpstreamNodes(upstream, DoEach); + }; + auto VisitEachPacked = [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_upstream_native, IsNative(upstream)); + ADT_CHECK(!is_upstream_native); + return backend_graph.VisitUpstreamNodes(upstream, VisitEachNative); + }; + auto 
DoEachOperandOrResult = + [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_native, IsNative(upstream)); + if (is_native) { + return DoEach(upstream); + } else { + return backend_graph.VisitUpstreamNodes(upstream, VisitEachPacked); + } + }; + return backend_graph.VisitUpstreamNodes(node, DoEachOperandOrResult); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + ADT_LET_CONST_REF(is_node_native, IsNative(node)); + ADT_CHECK(is_node_native); + auto VisitEachNative = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_downstream_native, IsNative(downstream)); + ADT_CHECK(!is_downstream_native); + return backend_graph.VisitDownstreamNodes(downstream, DoEach); + }; + auto VisitEachPacked = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_downstream_native, IsNative(downstream)); + ADT_CHECK(!is_downstream_native); + return backend_graph.VisitDownstreamNodes(downstream, VisitEachNative); + }; + auto DoEachOperandOrResult = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_native, IsNative(downstream)); + if (is_native) { + return DoEach(downstream); + } else { + return backend_graph.VisitDownstreamNodes(downstream, VisitEachPacked); + } + }; + return backend_graph.VisitDownstreamNodes(node, DoEachOperandOrResult); + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return backend_graph.GetSmallGraphNodeTopoCstr(node); + } + + adt::Result IgnoredNode(const NodeT& node) const { + ADT_LET_CONST_REF(is_native, IsNative(node)); + if (!is_native) { + return true; + } + return backend_graph.IgnoredNode(node); + } + + adt::Result IsOpNode(const NodeT& node) const { + return backend_graph.IsOpNode(node); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + return backend_graph.TopoSatisfy(node, node_topo_cstr); + } + + adt::Result IsNative(const NodeT& node) 
const { + ADT_LET_CONST_REF(drr_node, node.Get()); + return drr_node.Match( + [&](const NativeIrOpOperand&) -> bool { return true; }, + [&](const NativeIrOpResult&) -> bool { return true; }, + [&](const auto&) -> bool { return false; }); + } +}; + +} // namespace ap::drr + +namespace ap::graph { + +template <> +struct GraphDescriptor, drr::topo_kind::Default> + : public drr::DefaultDrrGraphDescriptor {}; + +template <> +struct GraphDescriptor, + drr::topo_kind::AllOperandAndResult> + : public drr::AllOperandAndResultDrrGraphDescriptor {}; + +template <> +struct GraphDescriptor, + drr::topo_kind::NativeOperandAndResult> + : public drr::NativeOperandAndResultDrrGraphDescriptor {}; + +} // namespace ap::graph diff --git a/paddle/ap/include/drr/drr_interpreter.h b/paddle/ap/include/drr/drr_interpreter.h new file mode 100644 index 00000000000000..0fc33a22fbe590 --- /dev/null +++ b/paddle/ap/include/drr/drr_interpreter.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/registry/abstract_drr_pass_registry_item.h" + +namespace ap::drr { + +class DrrInterpreter { + public: + explicit DrrInterpreter( + const axpr::TypeImpl>& + backend_ir_ctx, + const std::weak_ptr& + circlable_ref_list); + + using Function = ap::axpr::Value; + + using DrrNode = ap::drr::Node; + using DrrCtx = ap::drr::DrrCtx; + + ap::adt::Result Interpret(const Function& function, + const std::vector& args) { + return interpreter_.Interpret(function, args); + } + + ap::adt::Result InterpretDrrCtxMaker( + const Function& lambda, const std::vector& args); + + ap::adt::Result InterpretPass( + const Function& function, const std::string& abstract_drr_pass_name); + + ap::adt::Result InterpretPass( + const ap::axpr::ClassAttrs& cls); + + ap::adt::Result CreateDrrCtxByDrrPassObj( + const ap::axpr::Value& drr_pass_obj); + + private: + axpr::Interpreter interpreter_; +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/drr_node_descriptor.h b/paddle/ap/include/drr/drr_node_descriptor.h new file mode 100644 index 00000000000000..b4c20fed923966 --- /dev/null +++ b/paddle/ap/include/drr/drr_node_descriptor.h @@ -0,0 +1,129 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_descriptor.h" + +namespace ap::drr { + +struct DrrNodeDescriptor { + using DrrNode = drr::Node; + + using DrrNativeIrValue = ap::drr::NativeIrValue; + using DrrPackedIrValue = ap::drr::PackedIrValue; + using DrrNativeIrOp = ap::drr::NativeIrOp; + using DrrPackedIrOp = ap::drr::PackedIrOp; + using DrrOptPackedIrOp = ap::drr::OptPackedIrOp; + using DrrNativeIrOpOperand = ap::drr::NativeIrOpOperand; + using DrrPackedIrOpOperand = ap::drr::PackedIrOpOperand; + using DrrOptPackedIrOpOperand = ap::drr::OptPackedIrOpOperand; + using DrrNativeIrOpResult = ap::drr::NativeIrOpResult; + using DrrPackedIrOpResult = ap::drr::PackedIrOpResult; + using DrrOptPackedIrOpResult = ap::drr::OptPackedIrOpResult; + + std::string DebugId(const graph::Node& node) { + const auto& opt_drr_node = node.Get(); + if (opt_drr_node.HasError()) { + return std::to_string(node.node_id().value()); + } + const auto& drr_node = opt_drr_node.GetOkValue(); + return drr_node.Match( + [&](const DrrNativeIrValue& ir_value) -> std::string { + return ir_value->name; + }, + [&](const DrrPackedIrValue& ir_value) -> std::string { + return ir_value->name; + }, + [&](const DrrNativeIrOp& ir_op) -> std::string { + return ir_op->op_declare->op_name + "[" + ir_op->name + "]"; + }, + [&](const DrrPackedIrOp& ir_op) -> std::string { + return ir_op->op_declare->op_name + "[" + ir_op->name + "]"; + }, + [&](const DrrOptPackedIrOp& ir_op) -> std::string { + return std::string("opt-") + ir_op->op_declare->op_name + "[" + + ir_op->name + "]"; + }, + [&](const DrrNativeIrOpOperand& ir_op_operand) -> std::string { + return EdgeDebugId(node); + }, + [&](const DrrPackedIrOpOperand& ir_op_operand) -> std::string { + return EdgeDebugId(node); + }, + [&](const DrrOptPackedIrOpOperand& ir_op_operand) -> std::string { + return EdgeDebugId(node); + }, + 
[&](const DrrNativeIrOpResult& ir_op_result) -> std::string { + return EdgeDebugId(node); + }, + [&](const DrrPackedIrOpResult& ir_op_result) -> std::string { + return EdgeDebugId(node); + }, + [&](const DrrOptPackedIrOpResult& ir_op_result) -> std::string { + return EdgeDebugId(node); + }); + } + + std::string EdgeDebugId(const graph::Node& node) { + const auto& opt_src_and_dst = GetSrcAndDst(node); + if (!opt_src_and_dst.has_value()) { + return std::string("invalid_edge_") + + std::to_string(node.node_id().value()); + } + const auto& [src, dst] = opt_src_and_dst.value(); + return DebugId(src) + "->" + DebugId(dst); + } + + struct SrcAndDst { + graph::Node src; + graph::Node dst; + }; + + std::optional GetSrcAndDst(const graph::Node& node) { + const auto& opt_src_and_dst = TryGetSrcAndDst(node); + if (opt_src_and_dst.HasError()) { + return std::nullopt; + } + return opt_src_and_dst.GetOkValue(); + } + + adt::Result TryGetSrcAndDst(const graph::Node& node) { + ADT_LET_CONST_REF(upstreams, node.UpstreamNodes()); + ADT_LET_CONST_REF(downstreams, node.DownstreamNodes()); + ADT_LET_CONST_REF(src, upstreams.Sole()); + ADT_LET_CONST_REF(dst, downstreams.Sole()); + return SrcAndDst{src, dst}; + } + + adt::Result AttrsSatisfyIfBothAreOpsOrValues( + const drr::Node& node, const graph::Node& drr_node) { + return adt::errors::NotImplementedError{ + "NodeDescriptor>::AttrSatisfy() not " + "implemented"}; + } +}; + +} // namespace ap::drr + +namespace ap::graph { + +template <> +struct NodeDescriptor> : public drr::DrrNodeDescriptor { +}; + +} // namespace ap::graph diff --git a/paddle/ap/include/drr/drr_pass_type.h b/paddle/ap/include/drr/drr_pass_type.h new file mode 100644 index 00000000000000..d6a48a0392741b --- /dev/null +++ b/paddle/ap/include/drr/drr_pass_type.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/adt/adt.h"
+
+namespace ap::drr {
+
+struct AbstractDrrPassType : public std::monostate {
+  using std::monostate::monostate;
+};
+struct ReifiedDrrPassType : public std::monostate {
+  using std::monostate::monostate;
+};
+struct AccessTopoDrrPassType : public std::monostate {
+  using std::monostate::monostate;
+};
+
+using DrrPassTypeImpl = std::
+    variant<AbstractDrrPassType, ReifiedDrrPassType, AccessTopoDrrPassType>;
+
+struct DrrPassType : public DrrPassTypeImpl {
+  using DrrPassTypeImpl::DrrPassTypeImpl;
+  ADT_DEFINE_VARIANT_METHODS(DrrPassTypeImpl);
+};
+
+}  // namespace ap::drr
diff --git a/paddle/ap/include/drr/drr_pass_type_helper.h b/paddle/ap/include/drr/drr_pass_type_helper.h
new file mode 100644
index 00000000000000..1a89537c595640
--- /dev/null
+++ b/paddle/ap/include/drr/drr_pass_type_helper.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <optional>
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/drr/drr_pass_type.h"
+
+namespace ap::drr {
+
+struct DrrPassTypeHelper {
+  bool SupportReifying(const std::optional<DrrPassType>& type) const {
+    if (!type.has_value()) return false;
+    return type.value().Match(
+        [&](const AbstractDrrPassType&) { return true; },
+        [&](const ReifiedDrrPassType&) { return false; },
+        [&](const AccessTopoDrrPassType&) { return false; });
+  }
+
+  bool SupportOptionalPackedOp(const std::optional<DrrPassType>& type) const {
+    if (!type.has_value()) return false;
+    return type.value().Match(
+        [&](const AbstractDrrPassType&) { return true; },
+        [&](const ReifiedDrrPassType&) { return false; },
+        [&](const AccessTopoDrrPassType&) { return false; });
+  }
+};
+
+}  // namespace ap::drr
diff --git a/paddle/ap/include/drr/drr_value.h b/paddle/ap/include/drr/drr_value.h
new file mode 100644
index 00000000000000..9bc5d49a8c382e
--- /dev/null
+++ b/paddle/ap/include/drr/drr_value.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once +#include +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/drr_ctx.h" +#include "paddle/ap/include/drr/native_ir_op.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/packed_ir_op.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/result_pattern_ctx.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" +#include "paddle/ap/include/drr/unbound_opt_packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_value.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::drr { + +using DrrValueImpl = std::variant, + UnboundPackedIrValue, + NativeIrOp, + PackedIrOp, + OptPackedIrOp, + tSrcPtn>, + tSrcPtn>, + OptPackedIrOpDeclare, + tSrcPtn>, + tSrcPtn>, + UnboundOptPackedIrOp, + tSrcPtn>, + tSrcPtn>, + tSrcPtn, + tSrcPtn, + tStarred>>, + SourcePatternCtx, + tResPtn>, + tResPtn>, + tResPtn>, + tResPtn>, + tResPtn>, + tResPtn>, + tResPtn, + tResPtn, + tStarred>>, + ResultPatternCtx, + DrrCtx>; + +struct DrrValue : public DrrValueImpl { + using DrrValueImpl::DrrValueImpl; + ADT_DEFINE_VARIANT_METHODS(DrrValueImpl); + + template + decltype(auto) DrrValueMatch(Args&&... args) const { + return Match(std::forward(args)...); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/drr_value_helper.h b/paddle/ap/include/drr/drr_value_helper.h new file mode 100644 index 00000000000000..1cbf900f4c03c7 --- /dev/null +++ b/paddle/ap/include/drr/drr_value_helper.h @@ -0,0 +1,91 @@ +// Copyright (c) 2024 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/drr_value.h" + +namespace ap::drr { + +struct DrrValueHelper { + using This = DrrValueHelper; + + DrrValue CastFromAxprValue(const axpr::Value& axpr_val) { + return axpr_val.Match( + [&](const axpr::BuiltinClassInstance& instance) + -> DrrValue { return CastInstanceToDrrValue(instance); }, + [&](const auto&) -> DrrValue { return axpr_val; }); + } + + axpr::Value CastToAxprValue(const DrrValue& drr_value) { + return drr_value.Match( + [&](const axpr::Value& axpr_val) -> axpr::Value { return axpr_val; }, + [&](const auto& impl) -> axpr::Value { + using T = std::decay_t; + using TT = drr::Type; + return TT::GetClass().New(impl); + }); + } + + private: + using AxprInstanceToDrrValueConverter = + DrrValue (*)(const axpr::BuiltinClassInstance&); + using AxprInstanceToDrrValueMap = + std::map; + + DrrValue CastInstanceToDrrValue( + const axpr::BuiltinClassInstance& instance) { + const AxprInstanceToDrrValueMap& map = GetAxprInstanceToDrrValueMap(); + const auto& iter = map.find(instance.instance.type()); + if (iter == map.end()) { + return axpr::Value{instance}; + } else { + return iter->second(instance); + } + } + + const AxprInstanceToDrrValueMap& GetAxprInstanceToDrrValueMap() { + static const 
AxprInstanceToDrrValueMap map(MakeAxprInstanceToDrrValueMap()); + return map; + } + + AxprInstanceToDrrValueMap MakeAxprInstanceToDrrValueMap() { + AxprInstanceToDrrValueMap map; + InsertEntries(&map); + return map; + } + + template + void InsertEntries(AxprInstanceToDrrValueMap* map) { + if constexpr (start_idx >= std::variant_size_v) { + return; + } else { + using Impl = typename std::variant_alternative_t; + (*map)[typeid(Impl)] = + &This::template ConvertAxprInstanceToDrrValue; + InsertEntries(map); + } + } + + template + static DrrValue ConvertAxprInstanceToDrrValue( + const axpr::BuiltinClassInstance& instance) { + return std::any_cast(instance.instance); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/ir_op.h b/paddle/ap/include/drr/ir_op.h new file mode 100644 index 00000000000000..37a6880e53528a --- /dev/null +++ b/paddle/ap/include/drr/ir_op.h @@ -0,0 +1,63 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/drr/native_ir_op.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/opt_packed_ir_op.h" +#include "paddle/ap/include/drr/packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" +#include "paddle/ap/include/drr/unbound_opt_packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +using IrOpImpl = std::variant, + PackedIrOp, + OptPackedIrOp, + UnboundNativeIrOp, + UnboundPackedIrOp, + UnboundOptPackedIrOp>; + +struct IrOp : public IrOpImpl { + using IrOpImpl::IrOpImpl; + ADT_DEFINE_VARIANT_METHODS(IrOpImpl); + + const std::string& op_name() const { + using RetT = const std::string&; + return Match( + [&](const NativeIrOp& impl) -> RetT { + return impl->op_declare->op_name; + }, + [&](const PackedIrOp& impl) -> RetT { + return impl->op_declare->op_name; + }, + [&](const OptPackedIrOp& impl) -> RetT { + return impl->op_declare->op_name; + }, + [&](const UnboundNativeIrOp& impl) -> RetT { + return impl->op_declare->op_name; + }, + [&](const UnboundPackedIrOp& impl) -> RetT { + return impl->op_declare->op_name; + }, + [&](const UnboundOptPackedIrOp& impl) -> RetT { + return impl->op_declare->op_name; + }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/ir_value.h b/paddle/ap/include/drr/ir_value.h new file mode 100644 index 00000000000000..cb6b277e557d68 --- /dev/null +++ b/paddle/ap/include/drr/ir_value.h @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/packed_ir_value.h" + +namespace ap::drr { + +using IrValueImpl = + std::variant, PackedIrValue>; + +struct IrValue : public IrValueImpl { + using IrValueImpl::IrValueImpl; + ADT_DEFINE_VARIANT_METHODS(IrValueImpl); + + const graph::Node& node() const { + return Match([](const auto& impl) -> const graph::Node& { + return impl->node; + }); + } + + static std::optional OptCastFrom(const drr::Node& drr_node) { + using RetT = std::optional; + return drr_node.Match( + [](const NativeIrValue& ir_value) -> RetT { + return IrValue{ir_value}; + }, + [](const PackedIrValue& ir_value) -> RetT { + return IrValue{ir_value}; + }, + [](const auto&) -> RetT { return std::nullopt; }); + } + + const std::string& name() const { + return Match( + [](const auto& impl) -> const std::string& { return impl->name; }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_op.h b/paddle/ap/include/drr/native_ir_op.h new file mode 100644 index 00000000000000..d113f90db27426 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_op.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/native_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct NativeIrOpImpl { + graph::Node node; + NativeIrOpDeclare op_declare; + std::string name; + + bool operator==(const NativeIrOpImpl& other) const { + return this->node == other.node && this->op_declare == other.op_declare && + this->name == other.name; + } + + graph::NativeIrOpTopoCstr node_topo_cstr() const { + return graph::NativeIrOpTopoCstr{this->op_declare->op_name}; + } +}; + +template +ADT_DEFINE_RC(NativeIrOp, NativeIrOpImpl); + +axpr::TypeImpl> GetNativeIrOpClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "NativeIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetNativeIrOpClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_op_declare.h b/paddle/ap/include/drr/native_ir_op_declare.h new file mode 100644 index 00000000000000..c287adafc996f0 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_op_declare.h @@ -0,0 +1,71 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" + +namespace ap::drr { + +struct OpPatternCtxImpl; + +template +struct NativeIrOpDeclareImpl { + std::string op_name; + std::weak_ptr op_pattern_ctx; + axpr::AttrMap attr_map; + + bool operator==(const NativeIrOpDeclareImpl& other) const { + return this->op_name == other.op_name && + this->op_pattern_ctx.lock() == other.op_pattern_ctx.lock() && + this->attr_map == other.attr_map; + } +}; + +template +ADT_DEFINE_RC(NativeIrOpDeclare, NativeIrOpDeclareImpl); + +axpr::TypeImpl> +GetSrcPtnNativeIrOpDeclareClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SrcPtnNativeIrOpDeclare"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnNativeIrOpDeclareClass(); + } +}; + +axpr::TypeImpl> +GetResPtnNativeIrOpDeclareClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResPtnNativeIrOpDeclare"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnNativeIrOpDeclareClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_op_declare_method_class.h b/paddle/ap/include/drr/native_ir_op_declare_method_class.h new file mode 100644 index 00000000000000..20ebddcdf660e4 --- /dev/null +++ 
b/paddle/ap/include/drr/native_ir_op_declare_method_class.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/native_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnNativeIrOpDeclareClass(); + +axpr::TypeImpl> +GetResPtnNativeIrOpDeclareClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_op_method_class.h b/paddle/ap/include/drr/native_ir_op_method_class.h new file mode 100644 index 00000000000000..061ac953526ed5 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_op_method_class.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/native_ir_op.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> GetNativeIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_op_operand.h b/paddle/ap/include/drr/native_ir_op_operand.h new file mode 100644 index 00000000000000..89b21b5f073ce8 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_op_operand.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct NativeIrOpOperandImpl { + graph::Node node; + std::size_t index; + + bool operator==(const NativeIrOpOperandImpl& other) const { + return this->node == other.node && this->index == other.index; + } + + graph::NativeIrOpOperandTopoCstr node_topo_cstr() const { + return graph::NativeIrOpOperandTopoCstr{index}; + } +}; + +template +ADT_DEFINE_RC(NativeIrOpOperand, NativeIrOpOperandImpl); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_op_result.h b/paddle/ap/include/drr/native_ir_op_result.h new file mode 100644 index 00000000000000..3dae1e02c62c52 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_op_result.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct NativeIrOpResultImpl { + graph::Node node; + std::size_t index; + + bool operator==(const NativeIrOpResultImpl& other) const { + return this->node == other.node && this->index == other.index; + } + + graph::NativeIrOpResultTopoCstr node_topo_cstr() const { + return graph::NativeIrOpResultTopoCstr{index}; + } +}; + +template +ADT_DEFINE_RC(NativeIrOpResult, NativeIrOpResultImpl); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_value.h b/paddle/ap/include/drr/native_ir_value.h new file mode 100644 index 00000000000000..409bbc7baf3442 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_value.h @@ -0,0 +1,72 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +struct TensorPatternCtxImpl; + +template +struct NativeIrValueImpl { + graph::Node node; + std::string name; + std::weak_ptr tensor_pattern_ctx; + + bool operator==(const NativeIrValueImpl& other) const { + return this->node == other.node && this->name == other.name && + this->tensor_pattern_ctx.lock() == other.tensor_pattern_ctx.lock(); + } + + graph::NativeIrValueTopoCstr node_topo_cstr() const { + return graph::NativeIrValueTopoCstr{}; + } +}; + +template +ADT_DEFINE_RC(NativeIrValue, NativeIrValueImpl); + +axpr::TypeImpl> +GetSrcPtnNativeIrValueClass(); + +template +struct Type>> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SrcPtnNativeIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnNativeIrValueClass(); + } +}; + +axpr::TypeImpl> +GetResPtnNativeIrValueClass(); + +template +struct Type>> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResPtnNativeIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnNativeIrValueClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/native_ir_value_method_class.h b/paddle/ap/include/drr/native_ir_value_method_class.h new file mode 100644 index 00000000000000..d64ba95b5adc34 --- /dev/null +++ b/paddle/ap/include/drr/native_ir_value_method_class.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnNativeIrValueClass(); + +axpr::TypeImpl> +GetResPtnNativeIrValueClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/node.h b/paddle/ap/include/drr/node.h new file mode 100644 index 00000000000000..40584788eaeb85 --- /dev/null +++ b/paddle/ap/include/drr/node.h @@ -0,0 +1,62 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/drr/native_ir_op.h" +#include "paddle/ap/include/drr/native_ir_op_operand.h" +#include "paddle/ap/include/drr/native_ir_op_result.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/opt_packed_ir_op.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_operand.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_result.h" +#include "paddle/ap/include/drr/packed_ir_op.h" +#include "paddle/ap/include/drr/packed_ir_op_operand.h" +#include "paddle/ap/include/drr/packed_ir_op_result.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +using NodeImpl = std::variant, + NativeIrOp, + NativeIrOpOperand, + NativeIrOpResult, + PackedIrValue, + PackedIrOp, + PackedIrOpOperand, + PackedIrOpResult, + OptPackedIrOp, + OptPackedIrOpOperand, + OptPackedIrOpResult>; + +struct Node : public NodeImpl { + using NodeImpl::NodeImpl; + ADT_DEFINE_VARIANT_METHODS(NodeImpl); + + const graph::Node& node() const { + return Match([](const auto& impl) -> const graph::Node& { + return impl->node; + }); + } + + graph::NodeTopoCstr node_topo_cstr() const { + return Match([](const auto& impl) -> graph::NodeTopoCstr { + return impl->node_topo_cstr(); + }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/op_pattern_ctx.h b/paddle/ap/include/drr/op_pattern_ctx.h new file mode 100644 index 00000000000000..319c92d4534560 --- /dev/null +++ b/paddle/ap/include/drr/op_pattern_ctx.h @@ -0,0 +1,71 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/ir_op.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node_arena.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::drr { + +struct DrrCtxImpl; + +struct OpPatternCtxImpl { + std::shared_ptr> node_arena; + mutable std::map uid2ir_op; + std::weak_ptr drr_ctx; + + bool operator==(const OpPatternCtxImpl& other) const { + return this == &other; + } +}; + +ADT_DEFINE_RC(OpPatternCtx, OpPatternCtxImpl); + +axpr::TypeImpl> +GetSrcPtnOpPatternCtxClass(); + +template <> +struct Type> : public std::monostate { + using value_type = drr::tSrcPtn; + + const char* Name() const { return "SrcPtnOpPatternCtx"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnOpPatternCtxClass(); + } +}; + +axpr::TypeImpl> +GetResPtnOpPatternCtxClass(); + +template <> +struct Type> : public std::monostate { + using value_type = drr::tResPtn; + + const char* Name() const { return "ResPtnOpPatternCtx"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnOpPatternCtxClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h b/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h new file mode 100644 index 00000000000000..21e5f2abde7aa6 --- /dev/null +++ b/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h @@ -0,0 +1,364 @@ +// 
Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/type.h" + +#include "paddle/ap/include/drr/ir_op.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" +#include "paddle/ap/include/drr/unbound_opt_packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_value.h" + +namespace ap::drr { + +struct OpTensorPatternCtxHelper { + using OpPtnCtx = OpPatternCtx; + using TensorPtnCtx = TensorPatternCtx; + + template + adt::Result> GetOptType(const IrValueT& ir_value) { + ADT_LET_CONST_REF(tensor_pattern_ctx, + adt::WeakPtrLock(ir_value->tensor_pattern_ctx)); + const auto& iter = tensor_pattern_ctx->uid2type.find(ir_value->name); + if (iter == tensor_pattern_ctx->uid2type.end()) { + return std::nullopt; + } + return iter->second; + } + + template + adt::Result SetType(const IrValueT& ir_value, + const axpr::Value& type) { + ADT_LET_CONST_REF(tensor_pattern_ctx, + adt::WeakPtrLock(ir_value->tensor_pattern_ctx)); + 
tensor_pattern_ctx->uid2type[ir_value->name] = type; + return adt::Ok{}; + } + + adt::Result ConnectIrOpAndIrValue( + const NativeIrOp& native_ir_op, + const adt::List>& inputs, + const adt::List>& outputs) { + ADT_LET_CONST_REF(op_upstream_nodes, native_ir_op->node.UpstreamNodes()); + ADT_CHECK(op_upstream_nodes.size() == 0); + ADT_LET_CONST_REF(op_downstream_nodes, + native_ir_op->node.DownstreamNodes()); + ADT_CHECK(op_downstream_nodes.size() == 0); + ADT_LET_CONST_REF( + op_pattern_ctx, + adt::WeakPtrLock(native_ir_op->op_declare->op_pattern_ctx)); + const auto& node_arena = op_pattern_ctx->node_arena; + for (int i = 0; i < inputs->size(); ++i) { + const auto& native_ir_op_operand = node_arena->New([&](const auto& node) { + return NativeIrOpOperand{node, i}; + }); + ADT_RETURN_IF_ERR( + inputs->at(i)->node.ConnectTo(native_ir_op_operand.node(), + graph::UnindexedTag{}, + graph::IndexedTag{})); + ADT_RETURN_IF_ERR(native_ir_op_operand.node().ConnectTo( + native_ir_op->node, + graph::IndexedTag{}, + graph::IndexedTag{})); + } + for (int i = 0; i < outputs->size(); ++i) { + ADT_LET_CONST_REF(output_upstream_nodes, + outputs->at(i)->node.UpstreamNodes()); + ADT_CHECK(output_upstream_nodes.size() == 0); + const auto& native_ir_op_result = node_arena->New([&](const auto& node) { + return NativeIrOpResult{node, i}; + }); + ADT_RETURN_IF_ERR( + native_ir_op->node.ConnectTo(native_ir_op_result.node(), + graph::IndexedTag{}, + graph::IndexedTag{})); + ADT_RETURN_IF_ERR(native_ir_op_result.node().ConnectTo( + outputs->at(i)->node, + graph::IndexedTag{}, + graph::IndexedTag{})); + } + SetIrOpByUid(op_pattern_ctx, native_ir_op->name, native_ir_op); + return adt::Nothing{}; + } + + adt::Result ConnectIrOpAndIrValue( + const PackedIrOp& packed_ir_op, + const adt::List& inputs, + const adt::List& outputs) { + ADT_LET_CONST_REF(op_upstream_nodes, packed_ir_op->node.UpstreamNodes()); + ADT_CHECK(op_upstream_nodes.size() == 0); + ADT_LET_CONST_REF(op_downstream_nodes, + 
packed_ir_op->node.DownstreamNodes()); + ADT_CHECK(op_downstream_nodes.size() == 0); + ADT_LET_CONST_REF( + op_pattern_ctx, + adt::WeakPtrLock(packed_ir_op->op_declare->op_pattern_ctx)); + const auto& node_arena = op_pattern_ctx->node_arena; + for (int i = 0; i < inputs->size(); ++i) { + const auto& packed_ir_op_operand = node_arena->New([&](const auto& node) { + return PackedIrOpOperand{node, i}; + }); + ADT_RETURN_IF_ERR( + inputs->at(i).node().ConnectTo(packed_ir_op_operand.node(), + graph::UnindexedTag{}, + graph::IndexedTag{})); + ADT_RETURN_IF_ERR(packed_ir_op_operand.node().ConnectTo( + packed_ir_op->node, + graph::IndexedTag{}, + graph::UnindexedTag{})); + } + for (int i = 0; i < outputs->size(); ++i) { + ADT_LET_CONST_REF(output_upstream_nodes, + outputs->at(i).node().UpstreamNodes()); + ADT_CHECK(output_upstream_nodes.size() == 0); + const auto& packed_ir_op_result = node_arena->New([&](const auto& node) { + return PackedIrOpResult{node, i}; + }); + ADT_RETURN_IF_ERR( + packed_ir_op->node.ConnectTo(packed_ir_op_result.node(), + graph::UnindexedTag{}, + graph::IndexedTag{})); + ADT_RETURN_IF_ERR(packed_ir_op_result.node().ConnectTo( + outputs->at(i).node(), + graph::IndexedTag{}, + graph::IndexedTag{})); + } + SetIrOpByUid(op_pattern_ctx, packed_ir_op->name, packed_ir_op); + return adt::Nothing{}; + } + + adt::Result ConnectIrOpAndIrValue( + const OptPackedIrOp& packed_ir_op, + const adt::List& inputs, + const adt::List& outputs) { + ADT_LET_CONST_REF(op_upstream_nodes, packed_ir_op->node.UpstreamNodes()); + ADT_CHECK(op_upstream_nodes.size() == 0); + ADT_LET_CONST_REF(op_downstream_nodes, + packed_ir_op->node.DownstreamNodes()); + ADT_CHECK(op_downstream_nodes.size() == 0); + ADT_LET_CONST_REF( + op_pattern_ctx, + adt::WeakPtrLock(packed_ir_op->op_declare->op_pattern_ctx)); + const auto& node_arena = op_pattern_ctx->node_arena; + for (int i = 0; i < inputs->size(); ++i) { + const auto& packed_ir_op_operand = node_arena->New([&](const auto& node) { + 
return OptPackedIrOpOperand{node, i}; + }); + ADT_RETURN_IF_ERR( + inputs->at(i).node().ConnectTo(packed_ir_op_operand.node(), + graph::UnindexedTag{}, + graph::IndexedTag{})); + ADT_RETURN_IF_ERR(packed_ir_op_operand.node().ConnectTo( + packed_ir_op->node, + graph::IndexedTag{}, + graph::UnindexedTag{})); + } + for (int i = 0; i < outputs->size(); ++i) { + ADT_LET_CONST_REF(output_upstream_nodes, + outputs->at(i).node().UpstreamNodes()); + ADT_CHECK(output_upstream_nodes.size() == 0); + const auto& packed_ir_op_result = node_arena->New([&](const auto& node) { + return OptPackedIrOpResult{node, i}; + }); + ADT_RETURN_IF_ERR( + packed_ir_op->node.ConnectTo(packed_ir_op_result.node(), + graph::UnindexedTag{}, + graph::IndexedTag{})); + ADT_RETURN_IF_ERR(packed_ir_op_result.node().ConnectTo( + outputs->at(i).node(), + graph::IndexedTag{}, + graph::IndexedTag{})); + } + SetIrOpByUid(op_pattern_ctx, packed_ir_op->name, packed_ir_op); + return adt::Nothing{}; + } + + template + adt::Result GetIrOpByUid(const OpPtnCtxT& self, + const std::string& name) { + const auto& iter = self->uid2ir_op.find(name); + if (iter == self->uid2ir_op.end()) { + return adt::errors::AttributeError{std::string() + "no op named '" + + name + "' registered."}; + } + return iter->second; + } + + template + adt::Result CheckIrOpNameByUid(const OpPtnCtxT& self, + const std::string& name, + const IrOp& ir_op) { + ADT_LET_CONST_REF(existed_ir_op, GetIrOpByUid(self, name)); + ADT_CHECK(ir_op.op_name() == existed_ir_op.op_name()) + << adt::errors::TypeError{ + std::string() + "CheckIrOpNameByUid() failed. 
lhs: " + + ir_op.op_name() + ", rhs: " + existed_ir_op.op_name() + ""}; + return adt::Ok{}; + } + + template + bool HasIrOpByUid(const OpPtnCtxT& self, const std::string& name) { + return self->uid2ir_op.count(name) > 0; + } + + template + void SetIrOpByUid(const OpPtnCtxT& self, + const std::string& name, + const IrOp& ir_op) { + self->uid2ir_op[name] = ir_op; + } + + template + bool HasIrValueByUid(const TensorPtnCtxT& self, const std::string& name) { + return self->uid2ir_value.count(name); + } + + template + adt::Result GetIrValueByUid(const TensorPtnCtxT& self, + const std::string& name) { + const auto& iter = self->uid2ir_value.find(name); + if (iter == self->uid2ir_value.end()) { + return adt::errors::AttributeError{std::string() + "no tensor named '" + + name + "' registered."}; + } + return iter->second; + } + + template + void SetIrValueByUid(const TensorPtnCtxT& self, + const std::string& name, + const IrValue& ir_value) { + self->uid2ir_value[name] = ir_value; + } + + adt::Result> CloneIrValueDataAndRegister( + const TensorPtnCtx& self, + const NativeIrValue& native_ir_value) { + const auto& cloned_node = self->node_arena->New([&](const auto& node) { + return NativeIrValue{ + node, native_ir_value->name, self.shared_ptr()}; + }); + ADT_CHECK(cloned_node.template Has>()); + const auto& cloned = cloned_node.template Get>(); + SetIrValueByUid(self, native_ir_value->name, cloned); + return cloned; + } + + adt::Result> CloneIrValueDataAndRegister( + const TensorPtnCtx& self, + const PackedIrValue& packed_ir_value) { + const auto& cloned_node = self->node_arena->New([&](const auto& node) { + return PackedIrValue{node, packed_ir_value->name}; + }); + ADT_CHECK(cloned_node.template Has>()); + const auto& cloned = cloned_node.template Get>(); + SetIrValueByUid(self, packed_ir_value->name, cloned); + return cloned; + } + + adt::Result> GetNativeIrOpByUnboundNativeIrOp( + const UnboundNativeIrOp& ir_op) { + ADT_LET_CONST_REF(op_pattern_ctx, + 
adt::WeakPtrLock(ir_op->op_declare->op_pattern_ctx)); + const auto& node = op_pattern_ctx->node_arena->New([&](const auto& node) { + return NativeIrOp{node, ir_op->op_declare, ir_op->name}; + }); + ADT_CHECK(node.template Has>()); + return node.template Get>(); + } + + adt::Result> GetPackedIrOpByUnboundPackedIrOp( + const UnboundPackedIrOp& ir_op) { + ADT_LET_CONST_REF(op_pattern_ctx, + adt::WeakPtrLock(ir_op->op_declare->op_pattern_ctx)); + const auto& node = op_pattern_ctx->node_arena->New([&](const auto& node) { + return PackedIrOp{node, ir_op->op_declare, ir_op->name}; + }); + ADT_CHECK(node.template Has>()); + return node.template Get>(); + } + + adt::Result> GetOptPackedIrOpByUnboundOptPackedIrOp( + const UnboundOptPackedIrOp& ir_op) { + ADT_LET_CONST_REF(op_pattern_ctx, + adt::WeakPtrLock(ir_op->op_declare->op_pattern_ctx)); + const auto& node = op_pattern_ctx->node_arena->New([&](const auto& node) { + return OptPackedIrOp{node, ir_op->op_declare, ir_op->name}; + }); + return node.template TryGet>(); + } + + adt::Result> GetNativeIrValueByUnboundIrValue( + const UnboundIrValue& unbound_ir_value) { + ADT_LET_CONST_REF(tensor_ctx, + adt::WeakPtrLock(unbound_ir_value->tensor_pattern_ctx)); + if (HasIrValueByUid(tensor_ctx, unbound_ir_value->name)) { + ADT_LET_CONST_REF(ir_value, + GetIrValueByUid(tensor_ctx, unbound_ir_value->name)); + const auto& opt_ret = ir_value.Match( + [](const NativeIrValue& impl) + -> adt::Result> { return impl; }, + [&](const auto&) -> adt::Result> { + return adt::errors::RuntimeError{"only NativeIrValue supported."}; + }); + ADT_LET_CONST_REF(ret, opt_ret); + return ret; + } + const auto& node_arena = tensor_ctx->node_arena; + const auto& node = node_arena->New([&](const auto& node) { + return NativeIrValue{ + node, unbound_ir_value->name, unbound_ir_value->tensor_pattern_ctx}; + }); + ADT_CHECK(node.template Has>()); + const auto& native_ir_value = node.template Get>(); + SetIrValueByUid(tensor_ctx, native_ir_value->name, 
native_ir_value); + return native_ir_value; + } + + adt::Result> GetPackedIrValueByUnboundPackedIrValue( + const UnboundPackedIrValue& ir_value) { + ADT_LET_CONST_REF(tensor_ctx, + adt::WeakPtrLock(ir_value->tensor_pattern_ctx)); + if (HasIrValueByUid(tensor_ctx, ir_value->name)) { + ADT_LET_CONST_REF(ir_value, GetIrValueByUid(tensor_ctx, ir_value->name)); + const auto& opt_ret = ir_value.Match( + [](const PackedIrValue& impl) + -> adt::Result> { return impl; }, + [&](const auto&) -> adt::Result> { + return adt::errors::RuntimeError{"only PackedIrValue supported."}; + }); + ADT_LET_CONST_REF(ret, opt_ret); + return ret; + } + const auto& node_arena = tensor_ctx->node_arena; + const auto& node = node_arena->New([&](const auto& node) { + return PackedIrValue{node, ir_value->name}; + }); + ADT_CHECK(node.template Has>()); + const auto& packed_ir_value = node.template Get>(); + SetIrValueByUid(tensor_ctx, packed_ir_value->name, packed_ir_value); + return packed_ir_value; + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/opt_packed_ir_op.h b/paddle/ap/include/drr/opt_packed_ir_op.h new file mode 100644 index 00000000000000..a7fae212488b0e --- /dev/null +++ b/paddle/ap/include/drr/opt_packed_ir_op.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_declare.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct OptPackedIrOpImpl { + graph::Node node; + OptPackedIrOpDeclare op_declare; + std::string name; + + bool operator==(const OptPackedIrOpImpl& other) const { + return this->node == other.node && this->op_declare == other.op_declare && + this->name == other.name; + } + + graph::OptPackedIrOpTopoCstr node_topo_cstr() const { + return graph::OptPackedIrOpTopoCstr{this->op_declare->op_name}; + } +}; + +template +ADT_DEFINE_RC(OptPackedIrOp, OptPackedIrOpImpl); + +axpr::TypeImpl> GetOptPackedIrOpClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + + const char* Name() const { return "OptPackedIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetOptPackedIrOpClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/opt_packed_ir_op_declare.h b/paddle/ap/include/drr/opt_packed_ir_op_declare.h new file mode 100644 index 00000000000000..6a2703e3b6bc79 --- /dev/null +++ b/paddle/ap/include/drr/opt_packed_ir_op_declare.h @@ -0,0 +1,78 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" + +namespace ap::drr { + +struct OpPatternCtxImpl; + +template +struct OptPackedIrOpDeclareImpl { + std::string op_name; + std::weak_ptr op_pattern_ctx; + std::optional> data; + + bool operator==(const OptPackedIrOpDeclareImpl& other) const { + return this->op_name == other.op_name && + this->op_pattern_ctx.lock() == other.op_pattern_ctx.lock(); + } + + template + adt::Result cast_data() const { + auto ThisToString = [&]() { + const void* address = static_cast(this); + std::ostringstream ss; + ss << address; + return ss.str(); + }; + ADT_CHECK(data.has_value()) << adt::errors::ValueError{ + std::string() + "((OptPackedIrOpDeclareImpl*)" + ThisToString() + + ")->data is nullopt"}; + ADT_CHECK(data.value().get() != nullptr) << adt::errors::ValueError{ + std::string() + "((OptPackedIrOpDeclareImpl*)" + ThisToString() + + ")->data.value() is nullptr"}; + auto* ptr = dynamic_cast(data.value().get()); + ADT_CHECK(data.value().get() != nullptr) << adt::errors::ValueError{ + std::string() + "((OptPackedIrOpDeclareImpl*)" + ThisToString() + + ")->data.value() cast to " + typeid(T).name() + " failed."}; + return ptr; + } +}; + +template +ADT_DEFINE_RC(OptPackedIrOpDeclare, OptPackedIrOpDeclareImpl); + +axpr::TypeImpl> +GetOptPackedIrOpDeclareClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "OptPackedIrOpDeclare"; } + + static axpr::TypeImpl> GetClass() { + return GetOptPackedIrOpDeclareClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h b/paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h new file mode 100644 index 
00000000000000..c91238438a9384 --- /dev/null +++ b/paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetOptPackedIrOpDeclareClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/opt_packed_ir_op_method_class.h b/paddle/ap/include/drr/opt_packed_ir_op_method_class.h new file mode 100644 index 00000000000000..91cd57fa4b89d8 --- /dev/null +++ b/paddle/ap/include/drr/opt_packed_ir_op_method_class.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/opt_packed_ir_op.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> GetOptPackedIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/opt_packed_ir_op_operand.h b/paddle/ap/include/drr/opt_packed_ir_op_operand.h new file mode 100644 index 00000000000000..0b7d6244792436 --- /dev/null +++ b/paddle/ap/include/drr/opt_packed_ir_op_operand.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct OptPackedIrOpOperandImpl { + graph::Node node; + std::size_t local_uid; // not a index + + bool operator==(const OptPackedIrOpOperandImpl& other) const { + return this->node == other.node && this->local_uid == other.local_uid; + } + + graph::OptPackedIrOpOperandTopoCstr node_topo_cstr() const { + return graph::OptPackedIrOpOperandTopoCstr{}; + } +}; + +template +ADT_DEFINE_RC(OptPackedIrOpOperand, OptPackedIrOpOperandImpl); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/opt_packed_ir_op_result.h b/paddle/ap/include/drr/opt_packed_ir_op_result.h new file mode 100644 index 00000000000000..f1797a24e6110c --- /dev/null +++ b/paddle/ap/include/drr/opt_packed_ir_op_result.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct OptPackedIrOpResultImpl { + graph::Node node; + std::size_t local_uid; // not a index + + bool operator==(const OptPackedIrOpResultImpl& other) const { + return this->node == other.node && this->local_uid == other.local_uid; + } + + graph::OptPackedIrOpResultTopoCstr node_topo_cstr() const { + return graph::OptPackedIrOpResultTopoCstr{}; + } +}; + +template +ADT_DEFINE_RC(OptPackedIrOpResult, OptPackedIrOpResultImpl); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op.h b/paddle/ap/include/drr/packed_ir_op.h new file mode 100644 index 00000000000000..ec6152b7771112 --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct PackedIrOpImpl { + graph::Node node; + PackedIrOpDeclare op_declare; + std::string name; + + bool operator==(const PackedIrOpImpl& other) const { + return this->node == other.node && this->op_declare == other.op_declare && + this->name == other.name; + } + + graph::PackedIrOpTopoCstr node_topo_cstr() const { + return graph::PackedIrOpTopoCstr{this->op_declare->op_name}; + } +}; + +template +ADT_DEFINE_RC(PackedIrOp, PackedIrOpImpl); + +axpr::TypeImpl> GetPackedIrOpClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + + const char* Name() const { return "PackedIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetPackedIrOpClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op_declare.h b/paddle/ap/include/drr/packed_ir_op_declare.h new file mode 100644 index 00000000000000..05a9af558370a1 --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op_declare.h @@ -0,0 +1,93 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" + +namespace ap::drr { + +struct OpPatternCtxImpl; + +template +struct PackedIrOpDeclareImpl { + std::string op_name; + std::weak_ptr op_pattern_ctx; + std::optional> data; + + bool operator==(const PackedIrOpDeclareImpl& other) const { + return this->op_name == other.op_name && + this->op_pattern_ctx.lock() == other.op_pattern_ctx.lock(); + } + + template + adt::Result cast_data() const { + auto ThisToString = [&]() { + const void* address = static_cast(this); + std::ostringstream ss; + ss << address; + return ss.str(); + }; + ADT_CHECK(data.has_value()) + << adt::errors::ValueError{std::string() + "((PackedIrOpDeclareImpl*)" + + ThisToString() + ")->data is nullopt"}; + ADT_CHECK(data.value().get() != nullptr) << adt::errors::ValueError{ + std::string() + "((PackedIrOpDeclareImpl*)" + ThisToString() + + ")->data.value() is nullptr"}; + auto* ptr = dynamic_cast(data.value().get()); + ADT_CHECK(data.value().get() != nullptr) << adt::errors::ValueError{ + std::string() + "((PackedIrOpDeclareImpl*)" + ThisToString() + + ")->data.value() cast to " + typeid(T).name() + " failed."}; + return ptr; + } +}; + +template +ADT_DEFINE_RC(PackedIrOpDeclare, PackedIrOpDeclareImpl); + +axpr::TypeImpl> +GetSrcPtnPackedIrOpDeclareClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SrcPtnPackedIrOpDeclare"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnPackedIrOpDeclareClass(); + } +}; + +axpr::TypeImpl> +GetResPtnPackedIrOpDeclareClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResPtnPackedIrOpDeclare"; } + + 
static axpr::TypeImpl> GetClass() { + return GetResPtnPackedIrOpDeclareClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op_declare_data.h b/paddle/ap/include/drr/packed_ir_op_declare_data.h new file mode 100644 index 00000000000000..4a01d6c153e598 --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op_declare_data.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +class PackedIrOpDeclareData { + protected: + PackedIrOpDeclareData() {} + PackedIrOpDeclareData(const PackedIrOpDeclareData&) = default; + PackedIrOpDeclareData(PackedIrOpDeclareData&&) = default; + virtual ~PackedIrOpDeclareData() {} +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op_declare_method_class.h b/paddle/ap/include/drr/packed_ir_op_declare_method_class.h new file mode 100644 index 00000000000000..98107194618bd8 --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op_declare_method_class.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnPackedIrOpDeclareClass(); + +axpr::TypeImpl> +GetResPtnPackedIrOpDeclareClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op_method_class.h b/paddle/ap/include/drr/packed_ir_op_method_class.h new file mode 100644 index 00000000000000..09fd8e895ebd4d --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op_method_class.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/packed_ir_op.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> GetPackedIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op_operand.h b/paddle/ap/include/drr/packed_ir_op_operand.h new file mode 100644 index 00000000000000..2b9863fb12285e --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op_operand.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct PackedIrOpOperandImpl { + graph::Node node; + std::size_t local_uid; // not a index + + bool operator==(const PackedIrOpOperandImpl& other) const { + return this->node == other.node && this->local_uid == other.local_uid; + } + + graph::PackedIrOpOperandTopoCstr node_topo_cstr() const { + return graph::PackedIrOpOperandTopoCstr{}; + } +}; + +template +ADT_DEFINE_RC(PackedIrOpOperand, PackedIrOpOperandImpl); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_op_result.h b/paddle/ap/include/drr/packed_ir_op_result.h new file mode 100644 index 00000000000000..5a0686f9e48639 --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_op_result.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct PackedIrOpResultImpl { + graph::Node node; + std::size_t local_uid; // not a index + + bool operator==(const PackedIrOpResultImpl& other) const { + return this->node == other.node && this->local_uid == other.local_uid; + } + + graph::PackedIrOpResultTopoCstr node_topo_cstr() const { + return graph::PackedIrOpResultTopoCstr{}; + } +}; + +template +ADT_DEFINE_RC(PackedIrOpResult, PackedIrOpResultImpl); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_value.h b/paddle/ap/include/drr/packed_ir_value.h new file mode 100644 index 00000000000000..b69ec12206acb9 --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_value.h @@ -0,0 +1,97 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::drr { + +template +struct PackedIrValueImpl { + graph::Node node; + std::string name; + + bool operator==(const PackedIrValueImpl& other) const { + return this->node == other.node && this->name == other.name; + } + + graph::PackedIrValueTopoCstr node_topo_cstr() const { + return graph::PackedIrValueTopoCstr{}; + } +}; + +template +ADT_DEFINE_RC(PackedIrValue, PackedIrValueImpl); + +axpr::TypeImpl> +GetSrcPtnPackedIrValueClass(); + +template +struct Type>> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SrcPtnPackedIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnPackedIrValueClass(); + } +}; + +axpr::TypeImpl> +GetStarredSrcPtnPackedIrValueClass(); + +template +struct Type>>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "StarredSrcPtnPackedIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetStarredSrcPtnPackedIrValueClass(); + } +}; + +axpr::TypeImpl> +GetResPtnPackedIrValueClass(); + +template +struct Type>> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResPtnPackedIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnPackedIrValueClass(); + } +}; + +axpr::TypeImpl> +GetStarredResPtnPackedIrValueClass(); + +template +struct Type>>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "StarredResPtnPackedIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetStarredResPtnPackedIrValueClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/packed_ir_value_method_class.h 
b/paddle/ap/include/drr/packed_ir_value_method_class.h new file mode 100644 index 00000000000000..efa80880a5cd1a --- /dev/null +++ b/paddle/ap/include/drr/packed_ir_value_method_class.h @@ -0,0 +1,35 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnPackedIrValueClass(); +axpr::TypeImpl> +GetStarredSrcPtnPackedIrValueClass(); +axpr::TypeImpl> +GetResPtnPackedIrValueClass(); +axpr::TypeImpl> +GetStarredResPtnPackedIrValueClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h b/paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h new file mode 100644 index 00000000000000..8344a02c4db743 --- /dev/null +++ b/paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/native_ir_op_declare.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/packed_ir_op_declare.h" +#include "paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetResPtnOpPatternCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h b/paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h new file mode 100644 index 00000000000000..33aafaf0589ed0 --- /dev/null +++ b/paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h @@ -0,0 +1,35 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare_data.h" + +namespace ap::drr { + +class ResPtnPackedIrOpDeclareData : public PackedIrOpDeclareData { + public: + explicit ResPtnPackedIrOpDeclareData(const axpr::Value& code_gen_func) + : PackedIrOpDeclareData(), code_gen_func_(code_gen_func) {} + + const axpr::Value& code_gen_func() const { return code_gen_func_; } + + private: + axpr::Value code_gen_func_; +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h b/paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h new file mode 100644 index 00000000000000..d225357620408a --- /dev/null +++ b/paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" + +namespace ap::drr { +axpr::TypeImpl> +GetResPtnTensorPatternCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h b/paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h new file mode 100644 index 00000000000000..3bdc951cd4da47 --- /dev/null +++ b/paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/res_ptn_valid_out_ir_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetResPtnUnboundNativeIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h b/paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h new file mode 100644 index 00000000000000..90faf2ff6de6eb --- /dev/null +++ b/paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetResPtnUnboundPackedIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/res_ptn_valid_out_ir_value.h b/paddle/ap/include/drr/res_ptn_valid_out_ir_value.h new file mode 100644 index 00000000000000..848f553828a7c5 --- /dev/null +++ b/paddle/ap/include/drr/res_ptn_valid_out_ir_value.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" + +namespace ap::drr { + +using ResPtnValidOutIrValueImpl = + std::variant, tResPtn>>; + +struct ResPtnValidOutIrValue : public ResPtnValidOutIrValueImpl { + using ResPtnValidOutIrValueImpl::ResPtnValidOutIrValueImpl; + + ADT_DEFINE_VARIANT_METHODS(ResPtnValidOutIrValueImpl); + + static adt::Result CastFromAxprValue( + const axpr::Value& val) { + if (val.template CastableTo>()) { + ADT_LET_CONST_REF(ret, val.template CastTo>()); + return ret; + } + if (val.template CastableTo>>()) { + ADT_LET_CONST_REF( + ret, val.template CastTo>>()); + return ret; + } + return adt::errors::TypeError{ + "ResPtnValidOutIrValue::CastFromAxprValue() failed"}; + } + + const std::string& name() const { + return Match([](const tResPtn>& ir_value) + -> const std::string& { return ir_value.value()->name; }, + [](const auto& ir_value) -> const std::string& { + return ir_value->name; + }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/result_pattern_ctx.h b/paddle/ap/include/drr/result_pattern_ctx.h new file mode 100644 index 00000000000000..73f0830a5c271a --- /dev/null +++ b/paddle/ap/include/drr/result_pattern_ctx.h @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node_arena.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::drr { + +struct ResultPatternCtxImpl { + std::shared_ptr> node_arena; + OpPatternCtx op_pattern_ctx; + TensorPatternCtx tensor_pattern_ctx; + SourcePatternCtx source_pattern_ctx; + std::unordered_set internal_native_ir_value_names; + + bool operator==(const ResultPatternCtxImpl& other) const { + return this != &other; + } +}; + +ADT_DEFINE_RC(ResultPatternCtx, ResultPatternCtxImpl); + +axpr::TypeImpl> +GetResultPatternCtxClass(); + +template <> +struct Type : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResultPatternCtx"; } + + static axpr::TypeImpl> GetClass() { + return GetResultPatternCtxClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/result_pattern_ctx_method_class.h b/paddle/ap/include/drr/result_pattern_ctx_method_class.h new file mode 100644 index 00000000000000..6eba184321d742 --- /dev/null +++ b/paddle/ap/include/drr/result_pattern_ctx_method_class.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/result_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { +axpr::TypeImpl> +GetResultPatternCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/result_pattern_helper.h b/paddle/ap/include/drr/result_pattern_helper.h new file mode 100644 index 00000000000000..9eeb937861a0c4 --- /dev/null +++ b/paddle/ap/include/drr/result_pattern_helper.h @@ -0,0 +1,134 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/value.h" + +namespace ap::drr { + +struct ResultPatternHelper { + using DrrNode = drr::Node; + using DrrGraphNode = graph::Node; + + using DrrNativeIrValue = drr::NativeIrValue; + using DrrPackedIrValue = drr::PackedIrValue; + using DrrIrValue = drr::IrValue; + + using DrrNativeIrOp = drr::NativeIrOp; + using DrrNativeIrOpOperand = drr::NativeIrOpOperand; + using DrrNativeIrOpResult = drr::NativeIrOpResult; + using DrrPackedIrOp = drr::PackedIrOp; + using DrrPackedIrOpOperand = drr::PackedIrOpOperand; + using DrrPackedIrOpResult = drr::PackedIrOpResult; + using DrrOptPackedIrOp = drr::OptPackedIrOp; + using DrrOptPackedIrOpOperand = drr::OptPackedIrOpOperand; + using DrrOptPackedIrOpResult = drr::OptPackedIrOpResult; + + const DrrCtx& drr_ctx; + + template + adt::Result VisitResPtnInputIrValueByResPtnIrOp( + const DrrPackedIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + return VisitResPtnInputIrValueByResPtnIrOpImpl(res_ptn_ir_op, DoEach); + } + + template + adt::Result VisitResPtnInputIrValueByResPtnIrOp( + const DrrNativeIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + return VisitResPtnInputIrValueByResPtnIrOpImpl(res_ptn_ir_op, DoEach); + } + + template + adt::Result VisitResPtnInputIrValueByResPtnIrOpImpl( + const IrOpT& res_ptn_ir_op, const DoEachT& DoEach) const { + auto VisitOpOperand = + [&](const DrrGraphNode& op_operand) -> adt::Result { + ADT_LET_CONST_REF(op_operand_downstreams, op_operand.UpstreamNodes()); + ADT_LET_CONST_REF(ir_value_node, op_operand_downstreams.Sole()); + ADT_LET_CONST_REF(ir_value, 
ir_value_node.Get()); + const auto& opt_drr_ir_value = DrrIrValue::OptCastFrom(ir_value); + ADT_CHECK(opt_drr_ir_value.has_value()); + const auto& drr_ir_value = opt_drr_ir_value.value(); + return DoEach(drr_ir_value); + }; + ADT_LET_CONST_REF(upstreams, res_ptn_ir_op->node.UpstreamNodes()); + ADT_RETURN_IF_ERR(upstreams.VisitNodes(VisitOpOperand)); + return adt::Ok{}; + } + + template + adt::Result VisitResPtnOutputIrValueByResPtnIrOp( + const DrrPackedIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + return VisitResPtnOutputIrValueByResPtnIrOpImpl(res_ptn_ir_op, DoEach); + } + + template + adt::Result VisitResPtnOutputIrValueByResPtnIrOp( + const DrrNativeIrOp& res_ptn_ir_op, const DoEachT& DoEach) const { + return VisitResPtnOutputIrValueByResPtnIrOpImpl(res_ptn_ir_op, DoEach); + } + + template + adt::Result VisitResPtnOutputIrValueByResPtnIrOpImpl( + const IrOpT& res_ptn_ir_op, const DoEachT& DoEach) const { + auto VisitOpResult = + [&](const DrrGraphNode& op_result) -> adt::Result { + ADT_LET_CONST_REF(op_result_downstreams, op_result.DownstreamNodes()); + ADT_LET_CONST_REF(ir_node, op_result_downstreams.Sole()); + ADT_LET_CONST_REF(drr_ir_node, ir_node.Get()); + const auto& opt_drr_ir_value = DrrIrValue::OptCastFrom(drr_ir_node); + ADT_CHECK(opt_drr_ir_value.has_value()); + const auto& drr_ir_value = opt_drr_ir_value.value(); + return DoEach(drr_ir_value); + }; + ADT_LET_CONST_REF(downstreams, res_ptn_ir_op->node.DownstreamNodes()); + ADT_RETURN_IF_ERR(downstreams.VisitNodes(VisitOpResult)); + return adt::Ok{}; + } + + std::optional SrcPtnIrValue4ResPtnIrValue( + const DrrIrValue& res_ptn_ir_value) const { + const auto& opt_src_ptn_ctx = drr_ctx->GetSourcePatternCtx(); + if (opt_src_ptn_ctx.HasError()) { + return std::nullopt; + } + const auto& src_ptn_ctx = opt_src_ptn_ctx.GetOkValue(); + const auto& map = src_ptn_ctx->tensor_pattern_ctx->uid2ir_value; + auto GetSrcPtnIrValue = + [&](const auto& ir_value) -> std::optional { + const auto iter = 
map.find(ir_value->name); + if (iter == map.end()) { + return std::nullopt; + } + return iter->second; + }; + return res_ptn_ir_value.Match( + [&](const DrrNativeIrValue& ir_value) -> std::optional { + return GetSrcPtnIrValue(ir_value); + }, + [&](const DrrPackedIrValue& ir_value) -> std::optional { + return GetSrcPtnIrValue(ir_value); + }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/source_pattern_ctx.h b/paddle/ap/include/drr/source_pattern_ctx.h new file mode 100644 index 00000000000000..4ba0d663c736fd --- /dev/null +++ b/paddle/ap/include/drr/source_pattern_ctx.h @@ -0,0 +1,51 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node_arena.h" + +namespace ap::drr { + +struct SourcePatternCtxImpl { + std::shared_ptr> node_arena; + OpPatternCtx op_pattern_ctx; + TensorPatternCtx tensor_pattern_ctx; + + bool operator==(const SourcePatternCtxImpl& other) const { + return this != &other; + } +}; + +ADT_DEFINE_RC(SourcePatternCtx, SourcePatternCtxImpl); + +axpr::TypeImpl> +GetSourcePatternCtxClass(); + +template <> +struct Type : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SourcePatternCtx"; } + + static axpr::TypeImpl> GetClass() { + return GetSourcePatternCtxClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/source_pattern_ctx_method_class.h b/paddle/ap/include/drr/source_pattern_ctx_method_class.h new file mode 100644 index 00000000000000..823c4d9530e1f6 --- /dev/null +++ b/paddle/ap/include/drr/source_pattern_ctx_method_class.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/drr/tags.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSourcePatternCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h b/paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h new file mode 100644 index 00000000000000..234f1aa4661256 --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" + +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/native_ir_op_declare.h" +#include "paddle/ap/include/drr/op_pattern_ctx.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/packed_ir_op_declare.h" +#include "paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnOpPatternCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h b/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h new file mode 100644 index 00000000000000..f5f0fdda7c8ecc --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" + +namespace ap::drr { + +struct SrcPtnPackedIrOpDeclareData : public PackedIrOpDeclareData { + SrcPtnPackedIrOpDeclareData() : PackedIrOpDeclareData() {} + + std::optional> + inner_source_pattern_func; + + std::optional inner_source_pattern_ctx; +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h b/paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h new file mode 100644 index 00000000000000..7eb6edb890a4c3 --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/tensor_pattern_ctx.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnTensorPatternCtxClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h b/paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h new file mode 100644 index 00000000000000..784439277ee348 --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/native_ir_value.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_native_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnUnboundNativeIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h b/paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h new file mode 100644 index 00000000000000..9b353e0835ffab --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/interpreter_base.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/ir_op.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/src_ptn_valid_in_ir_value.h" +#include "paddle/ap/include/drr/src_ptn_valid_out_ir_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetSrcPtnUnboundPackedIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_valid_in_ir_value.h b/paddle/ap/include/drr/src_ptn_valid_in_ir_value.h new file mode 100644 index 00000000000000..d924e31aeaae61 --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_valid_in_ir_value.h @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +using SrcPtnValidInIrValueImpl = std::variant, + NativeIrValue, + UnboundIrValue, + UnboundPackedIrValue>; + +struct SrcPtnValidInIrValue : public SrcPtnValidInIrValueImpl { + using SrcPtnValidInIrValueImpl::SrcPtnValidInIrValueImpl; + + ADT_DEFINE_VARIANT_METHODS(SrcPtnValidInIrValueImpl); + + const std::string& name() const { + return Match([](const auto& ir_value) -> const std::string& { + return ir_value->name; + }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/src_ptn_valid_out_ir_value.h b/paddle/ap/include/drr/src_ptn_valid_out_ir_value.h new file mode 100644 index 00000000000000..86e98eb9438c04 --- /dev/null +++ b/paddle/ap/include/drr/src_ptn_valid_out_ir_value.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +using SrcPtnValidOutIrValueImpl = + std::variant, UnboundPackedIrValue>; + +struct SrcPtnValidOutIrValue : public SrcPtnValidOutIrValueImpl { + using SrcPtnValidOutIrValueImpl::SrcPtnValidOutIrValueImpl; + + ADT_DEFINE_VARIANT_METHODS(SrcPtnValidOutIrValueImpl); + + const std::string& name() const { + return Match([](const auto& ir_value) -> const std::string& { + return ir_value->name; + }); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/tags.h b/paddle/ap/include/drr/tags.h new file mode 100644 index 00000000000000..33dde31e379e2f --- /dev/null +++ b/paddle/ap/include/drr/tags.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::drr { + +// starred +ADT_DEFINE_TAG(tStarred); + +// source pattern +ADT_DEFINE_TAG(tSrcPtn); + +// result pattern +ADT_DEFINE_TAG(tResPtn); + +template +tSrcPtn SrcPtn(const T& value) { + return tSrcPtn{value}; +} + +template +tResPtn ResPtn(const T& value) { + return tResPtn{value}; +} + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/tensor_pattern_ctx.h b/paddle/ap/include/drr/tensor_pattern_ctx.h new file mode 100644 index 00000000000000..f6bd50ea944e0e --- /dev/null +++ b/paddle/ap/include/drr/tensor_pattern_ctx.h @@ -0,0 +1,70 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node_arena.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::drr { + +struct DrrCtxImpl; + +struct TensorPatternCtxImpl { + std::shared_ptr> node_arena; + mutable std::map uid2ir_value; + std::weak_ptr drr_ctx; + mutable std::map uid2type; + + bool operator==(const TensorPatternCtxImpl& other) const { + return this == &other; + } +}; + +ADT_DEFINE_RC(TensorPatternCtx, TensorPatternCtxImpl); + +axpr::TypeImpl> +GetSrcPtnTensorPatternCtxClass(); + +template <> +struct Type> : public std::monostate { + using value_type = drr::tSrcPtn; + + const char* Name() const { return "SrcPtnTensorPatternCtx"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnTensorPatternCtxClass(); + } +}; + +axpr::TypeImpl> +GetResPtnTensorPatternCtxClass(); + +template <> +struct Type> : public std::monostate { + using value_type = drr::tResPtn; + + const char* Name() const { return "ResPtnTensorPatternCtx"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnTensorPatternCtxClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/topo_kind.h b/paddle/ap/include/drr/topo_kind.h new file mode 100644 index 00000000000000..ff49aad02e3ae1 --- /dev/null +++ b/paddle/ap/include/drr/topo_kind.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace ap::drr::topo_kind { + +struct Default; + +// graph of all OpOperand and OpResult relationship. +struct AllOperandAndResult; + +// graph of native OpOperand and OpResult relationship. +struct NativeOperandAndResult; + +// graph with augmented reference value/op_operand/op/op_result. +struct RefAugmented; + +// bound to owner block + +struct BlockBound; + +} // namespace ap::drr::topo_kind diff --git a/paddle/ap/include/drr/type.h b/paddle/ap/include/drr/type.h new file mode 100644 index 00000000000000..0b00ecc3b8c8b5 --- /dev/null +++ b/paddle/ap/include/drr/type.h @@ -0,0 +1,22 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +namespace ap::drr { + +template +struct Type; + +} diff --git a/paddle/ap/include/drr/unbound_ir_value.h b/paddle/ap/include/drr/unbound_ir_value.h new file mode 100644 index 00000000000000..6826d7b131787f --- /dev/null +++ b/paddle/ap/include/drr/unbound_ir_value.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/type.h" + +namespace ap::drr { + +struct TensorPatternCtxImpl; + +template +struct UnboundIrValueImpl { + std::string name; + std::weak_ptr tensor_pattern_ctx; + + bool operator==(const UnboundIrValueImpl& other) const { + return this->name == other.name && + this->tensor_pattern_ctx.lock() == other.tensor_pattern_ctx.lock(); + } +}; + +template +ADT_DEFINE_RC(UnboundIrValue, UnboundIrValueImpl); + +axpr::TypeImpl> +GetUnboundIrValueClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "UnboundIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetUnboundIrValueClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_ir_value_method_class.h b/paddle/ap/include/drr/unbound_ir_value_method_class.h new file mode 100644 index 00000000000000..aa7f8298206e5c --- 
/dev/null +++ b/paddle/ap/include/drr/unbound_ir_value_method_class.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_packed_ir_value.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetUnboundIrValueClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_native_ir_op.h b/paddle/ap/include/drr/unbound_native_ir_op.h new file mode 100644 index 00000000000000..bc39db8efd4954 --- /dev/null +++ b/paddle/ap/include/drr/unbound_native_ir_op.h @@ -0,0 +1,68 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/native_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" + +namespace ap::drr { + +template +struct UnboundNativeIrOpImpl { + NativeIrOpDeclare op_declare; + std::string name; + + bool operator==(const UnboundNativeIrOpImpl& other) const { + return this->op_declare == other.op_declare && this->name == other.name; + } +}; + +template +ADT_DEFINE_RC(UnboundNativeIrOp, UnboundNativeIrOpImpl); + +axpr::TypeImpl> +GetSrcPtnUnboundNativeIrOpClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SrcPtnUnboundNativeIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnUnboundNativeIrOpClass(); + } +}; + +axpr::TypeImpl> +GetResPtnUnboundNativeIrOpClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResPtnUnboundNativeIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnUnboundNativeIrOpClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_opt_packed_ir_op.h b/paddle/ap/include/drr/unbound_opt_packed_ir_op.h new file mode 100644 index 00000000000000..d37e6cba6e474f --- /dev/null +++ b/paddle/ap/include/drr/unbound_opt_packed_ir_op.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" + +namespace ap::drr { + +template +struct UnboundOptPackedIrOpImpl { + public: + OptPackedIrOpDeclare op_declare; + std::string name; + bool operator==(const UnboundOptPackedIrOpImpl& other) const { + return this->op_declare == other.op_declare && this->name == other.name; + } +}; + +template +ADT_DEFINE_RC(UnboundOptPackedIrOp, UnboundOptPackedIrOpImpl); + +axpr::TypeImpl> +GetUnboundOptPackedIrOpClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "UnboundOptPackedIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetUnboundOptPackedIrOpClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h b/paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h new file mode 100644 index 00000000000000..a7f440be2e8e6e --- /dev/null +++ b/paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/packed_ir_value.h" +#include "paddle/ap/include/drr/src_ptn_valid_in_ir_value.h" +#include "paddle/ap/include/drr/src_ptn_valid_out_ir_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_ir_value.h" +#include "paddle/ap/include/drr/unbound_opt_packed_ir_op.h" +#include "paddle/ap/include/drr/unbound_packed_ir_op.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetUnboundOptPackedIrOpClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_packed_ir_op.h b/paddle/ap/include/drr/unbound_packed_ir_op.h new file mode 100644 index 00000000000000..729fa246f5a70f --- /dev/null +++ b/paddle/ap/include/drr/unbound_packed_ir_op.h @@ -0,0 +1,68 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/packed_ir_op_declare.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/type.h" +#include "paddle/ap/include/graph/node.h" + +namespace ap::drr { + +template +struct UnboundPackedIrOpImpl { + public: + PackedIrOpDeclare op_declare; + std::string name; + bool operator==(const UnboundPackedIrOpImpl& other) const { + return this->op_declare == other.op_declare && this->name == other.name; + } +}; + +template +ADT_DEFINE_RC(UnboundPackedIrOp, UnboundPackedIrOpImpl); + +axpr::TypeImpl> +GetSrcPtnUnboundPackedIrOpClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "SrcPtnUnboundPackedIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetSrcPtnUnboundPackedIrOpClass(); + } +}; + +axpr::TypeImpl> +GetResPtnUnboundPackedIrOpClass(); + +template +struct Type>> + : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "ResPtnUnboundPackedIrOp"; } + + static axpr::TypeImpl> GetClass() { + return GetResPtnUnboundPackedIrOpClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_packed_ir_value.h b/paddle/ap/include/drr/unbound_packed_ir_value.h new file mode 100644 index 00000000000000..3e3b642bffec5f --- /dev/null +++ b/paddle/ap/include/drr/unbound_packed_ir_value.h @@ -0,0 +1,52 @@ +// Copyright (c) 
2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/type.h" + +namespace ap::drr { + +struct TensorPatternCtxImpl; + +template +struct UnboundPackedIrValueImpl { + std::string name; + std::weak_ptr tensor_pattern_ctx; + bool operator==(const UnboundPackedIrValueImpl& other) const { + return this->name == other.name && + this->tensor_pattern_ctx.lock() == other.tensor_pattern_ctx.lock(); + } +}; + +template +ADT_DEFINE_RC(UnboundPackedIrValue, UnboundPackedIrValueImpl); + +axpr::TypeImpl> +GetUnboundPackedIrValueClass(); + +template +struct Type> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "UnboundPackedIrValue"; } + + static axpr::TypeImpl> GetClass() { + return GetUnboundPackedIrValueClass(); + } +}; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/unbound_packed_ir_value_method_class.h b/paddle/ap/include/drr/unbound_packed_ir_value_method_class.h new file mode 100644 index 00000000000000..7c961020a231ed --- /dev/null +++ b/paddle/ap/include/drr/unbound_packed_ir_value_method_class.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/drr_value.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/unbound_packed_ir_value.h" + +namespace ap::drr { + +axpr::TypeImpl> +GetUnboundPackedIrValueClass(); + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/value.h b/paddle/ap/include/drr/value.h new file mode 100644 index 00000000000000..7ec4b82fb3bd72 --- /dev/null +++ b/paddle/ap/include/drr/value.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/drr_ctx.h" +#include "paddle/ap/include/drr/node.h" + +namespace ap::drr { + +using axpr::Value; + +using Val = Value; + +} // namespace ap::drr diff --git a/paddle/ap/include/drr/value_method_class.h b/paddle/ap/include/drr/value_method_class.h new file mode 100644 index 00000000000000..a5164bc1d19966 --- /dev/null +++ b/paddle/ap/include/drr/value_method_class.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/drr_ctx_method_class.h" +#include "paddle/ap/include/drr/native_ir_op_declare_method_class.h" +#include "paddle/ap/include/drr/native_ir_op_method_class.h" +#include "paddle/ap/include/drr/native_ir_value_method_class.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h" +#include "paddle/ap/include/drr/opt_packed_ir_op_method_class.h" +#include "paddle/ap/include/drr/packed_ir_op_declare_method_class.h" +#include "paddle/ap/include/drr/packed_ir_op_method_class.h" +#include "paddle/ap/include/drr/packed_ir_value_method_class.h" +#include "paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h" +#include "paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h" +#include "paddle/ap/include/drr/result_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/source_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h" +#include "paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h" +#include "paddle/ap/include/drr/unbound_ir_value_method_class.h" +#include "paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h" +#include "paddle/ap/include/drr/unbound_packed_ir_value_method_class.h" diff --git a/paddle/ap/include/env/ap_path.h b/paddle/ap/include/env/ap_path.h new file mode 100644 index 00000000000000..ec2556f0d33176 --- /dev/null +++ b/paddle/ap/include/env/ap_path.h @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::env { + +template +adt::Result VisitEachApPath(const YieldT& Yield) { + const char* ap_path_chars = std::getenv("AP_PATH"); + if (ap_path_chars == nullptr) { + return adt::Ok{}; + } + std::string ap_path(ap_path_chars); + std::string path; + std::istringstream ss(ap_path); + while (std::getline(ss, path, ':')) { + if (!path.empty()) { + ADT_LET_CONST_REF(loop_ctr, Yield(path)); + if (loop_ctr.template Has()) { + break; + } + } + } + return adt::Ok{}; +} + +} // namespace ap::env diff --git a/paddle/ap/include/fs/fs.h b/paddle/ap/include/fs/fs.h new file mode 100644 index 00000000000000..a7591ef3a9ca5d --- /dev/null +++ b/paddle/ap/include/fs/fs.h @@ -0,0 +1,64 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::fs { + +inline bool FileExists(const std::string& filepath) { + std::fstream fp; + fp.open(filepath, std::fstream::in); + if (fp.is_open()) { + fp.close(); + return true; + } else { + return false; + } +} + +// reference: +// https://stackoverflow.com/questions/2602013/read-whole-ascii-file-into-c-stdstring +inline adt::Result ReadFileContent(const std::string& file_path, + std::string* content) { + std::ifstream ifs(file_path); + + ADT_CHECK(ifs.is_open()) << adt::errors::RuntimeError{ + std::string() + "file open failed. file_path: " + file_path}; + + ifs.seekg(0, std::ios::end); + content->reserve(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + + content->assign(std::istreambuf_iterator(ifs), + std::istreambuf_iterator()); + return adt::Ok{}; +} + +inline adt::Result WriteFileContent(const std::string& file_path, + const std::string& content) { + std::ofstream ofs{file_path}; + ADT_CHECK(ofs.is_open()) << adt::errors::RuntimeError{ + std::string() + "file open failed. file_path: " + file_path}; + ofs << content; + ofs.close(); + return adt::Ok{}; +} + +} // namespace ap::fs diff --git a/paddle/ap/include/graph/adt.h b/paddle/ap/include/graph/adt.h new file mode 100644 index 00000000000000..08764277dd151a --- /dev/null +++ b/paddle/ap/include/graph/adt.h @@ -0,0 +1,20 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/adt/bfs_walker.h" +#include "paddle/ap/include/adt/topo_walker.h" + +namespace ap::graph {} // namespace ap::graph diff --git a/paddle/ap/include/graph/graph_descriptor.h b/paddle/ap/include/graph/graph_descriptor.h new file mode 100644 index 00000000000000..1e1f3f6a0f83e4 --- /dev/null +++ b/paddle/ap/include/graph/graph_descriptor.h @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/ap/include/graph/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" + +namespace ap::graph { + +template +struct GraphDescriptor { + GraphDescriptor(const GraphDescriptor&) = default; + GraphDescriptor(GraphDescriptor&&) = default; + + template + adt::Result VisitUpstreamNodes(const NodeT&, + const DoEachT& DoEach) const; + + template + adt::Result VisitDownstreamNodes(const NodeT&, + const DoEachT& DoEach) const; + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT&) const; + + adt::Result IgnoredNode(const NodeT&) const; + + adt::Result IsOpNode(const NodeT&) const; + + adt::Result TopoSatisfy(const NodeT&, + const graph::SmallGraphNodeTopoCstr&) const; +}; + +} // namespace ap::graph diff --git a/paddle/ap/include/graph/graph_helper.h b/paddle/ap/include/graph/graph_helper.h new file mode 100644 index 00000000000000..3d7e5897c8b1cc --- /dev/null +++ b/paddle/ap/include/graph/graph_helper.h @@ -0,0 +1,178 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/ap/include/graph/adt.h" +#include "paddle/ap/include/graph/graph_descriptor.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_arena.h" + +namespace ap::graph { + +template +struct GraphHelper { + explicit GraphHelper(const GraphDescriptor& graph_descriptor) + : graph_descriptor_(graph_descriptor) {} + + GraphHelper(const GraphHelper&) = delete; + GraphHelper(GraphHelper&&) = delete; + + adt::Result FindAnchor(const NodeT& start) { + const auto& True = [](const auto&) -> adt::Result { return true; }; + ADT_LET_CONST_REF(opt_anchor, FilterAnchor(start, True)); + ADT_CHECK(opt_anchor.has_value()) << adt::errors::MismatchError{}; + return opt_anchor.value(); + } + + template + adt::Result> FilterAnchor(const NodeT& start, + const FilterT& Filter) { + const auto topo_walker = GetTopoWalker(); + const auto IsSource = [&](const NodeT& sg_node) -> adt::Result { + bool has_source = false; + auto SetHasSource = [&](const NodeT&) -> adt::Result { + has_source = true; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + graph_descriptor_.VisitUpstreamNodes(sg_node, SetHasSource)); + return !has_source; + }; + const auto IsSink = [&](const NodeT& sg_node) -> adt::Result { + bool has_sink = false; + auto SetHasSink = [&](const NodeT&) -> adt::Result { + has_sink = true; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + graph_descriptor_.VisitDownstreamNodes(sg_node, SetHasSink)); + return !has_sink; + }; + std::unordered_set source_or_sinks; + auto CollectStarts = [&](const NodeT& sg_node) -> adt::Result { + ADT_LET_CONST_REF(ignored, graph_descriptor_.IgnoredNode(sg_node)); + if (ignored) { + return adt::Ok{}; + } + ADT_LET_CONST_REF(is_source, IsSource(sg_node)); + ADT_LET_CONST_REF(is_sink, IsSink(sg_node)); + if (is_source || is_sink) { + source_or_sinks.insert(sg_node); + } + return adt::Ok{}; + }; + const auto bfs_walker_without_ignore = GetBfsWalkerWithoutIgnore(); + 
ADT_RETURN_IF_ERR(bfs_walker_without_ignore(start, CollectStarts)); + ADT_CHECK(source_or_sinks.size() > 0); + std::unordered_map node2depth; + std::map> depth2nodes; + const auto bfs_walker = GetBfsWalker(); + auto UpdateNodeDepth = [&](const NodeT& sg_node) -> adt::Result { + size_t max_depth = 0; + ADT_RETURN_IF_ERR(bfs_walker.VisitNextNodes( + sg_node, [&](const NodeT& prev) -> adt::Result { + const auto& iter = node2depth.find(prev); + if (iter != node2depth.end()) { + max_depth = std::max(max_depth, iter->second); + } + return adt::Ok{}; + })); + node2depth[sg_node] = max_depth; + depth2nodes[max_depth].push_back(sg_node); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(bfs_walker( + source_or_sinks.begin(), source_or_sinks.end(), UpdateNodeDepth)); + for (auto iter = depth2nodes.rbegin(); iter != depth2nodes.rend(); ++iter) { + for (const auto& node : iter->second) { + ADT_LET_CONST_REF(is_op_node, this->graph_descriptor_.IsOpNode(node)); + if (is_op_node) { + ADT_LET_CONST_REF(filter_success, Filter(node)); + if (filter_success) { + return node; + } + } + } + } + return std::nullopt; + } + + adt::BfsWalker GetBfsWalker() { + auto graph = this->graph_descriptor_; + const auto& ForEachNext = + [graph](const NodeT& node, + const auto& VisitNext) -> adt::Result { + auto DoEach = [&](const NodeT& next) -> adt::Result { + ADT_LET_CONST_REF(is_ignored, graph.IgnoredNode(next)); + if (is_ignored) { + return adt::Ok{}; + } + return VisitNext(next); + }; + ADT_RETURN_IF_ERR(graph.VisitDownstreamNodes(node, DoEach)); + ADT_RETURN_IF_ERR(graph.VisitUpstreamNodes(node, DoEach)); + return adt::Ok{}; + }; + return adt::BfsWalker(ForEachNext); + } + + adt::BfsWalker GetBfsWalkerWithoutIgnore() { + auto graph = this->graph_descriptor_; + const auto& ForEachNext = + [graph](const NodeT& node, + const auto& VisitNext) -> adt::Result { + ADT_RETURN_IF_ERR(graph.VisitDownstreamNodes(node, VisitNext)); + ADT_RETURN_IF_ERR(graph.VisitUpstreamNodes(node, VisitNext)); + return 
adt::Ok{}; + }; + return adt::BfsWalker(ForEachNext); + } + + adt::TopoWalker GetTopoWalker() { + auto graph = this->graph_descriptor_; + const auto& ForEachPrev = + [graph](const NodeT& node, + const auto& VisitPrev) -> adt::Result { + auto DoEach = [&](const NodeT& prev) -> adt::Result { + ADT_LET_CONST_REF(is_ignored, graph.IgnoredNode(prev)); + if (is_ignored) { + return adt::Ok{}; + } + return VisitPrev(prev); + }; + return graph.VisitUpstreamNodes(node, DoEach); + }; + const auto& ForEachNext = + [graph](const NodeT& node, + const auto& VisitNext) -> adt::Result { + auto DoEach = [&](const NodeT& next) -> adt::Result { + ADT_LET_CONST_REF(is_ignored, graph.IgnoredNode(next)); + if (is_ignored) { + return adt::Ok{}; + } + return VisitNext(next); + }; + return graph.VisitDownstreamNodes(node, DoEach); + }; + return adt::TopoWalker(ForEachPrev, ForEachNext); + } + + private: + GraphDescriptor graph_descriptor_; +}; + +} // namespace ap::graph diff --git a/paddle/ap/include/graph/node.h b/paddle/ap/include/graph/node.h new file mode 100644 index 00000000000000..cf4db572f624c4 --- /dev/null +++ b/paddle/ap/include/graph/node.h @@ -0,0 +1,86 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node_list.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::graph { + +template +class NodeArena; + +template +struct Node { + tNodeId node_id_; + std::weak_ptr> node_arena_; + + const tNodeId& node_id() const { return node_id_; } + std::weak_ptr> node_arena() const { return node_arena_; } + + adt::Result>> GetNodeArena() const { + auto ptr = node_arena_.lock(); + if (!ptr) { + return adt::errors::RuntimeError{"NodeArena is delete."}; + } + return ptr; + } + + adt::Result Get() const; + adt::Result> DownstreamNodes() const; + adt::Result> UpstreamNodes() const; + + adt::Result ConnectTo( + const Node& dst_node, + const ValidListTag& src_downstream_type, + const ValidListTag& dst_unstream_type) const; + + bool operator<(const Node& other) const { + if (!(this->node_id_ == other.node_id_)) { + return this->node_id_.value() < other.node_id_.value(); + } + return this->node_arena_.lock() < other.node_arena_.lock(); + } + + bool operator==(const Node& other) const { + return other.node_id_.value() == this->node_id_.value() && + other.node_arena_.lock() == this->node_arena_.lock(); + } + + bool operator!=(const Node& other) const { return !(*this == other); } + + std::size_t GetHashValue() const { + return adt::hash_combine( + this->node_id_.value(), + std::hash>>()(this->node_arena_.lock())); + } +}; + +} // namespace ap::graph + +namespace std { + +template +struct hash> { + std::size_t operator()(const ap::graph::Node& node) const { + return node.GetHashValue(); + } +}; + +} // namespace std diff --git a/paddle/ap/include/graph/node_arena.h b/paddle/ap/include/graph/node_arena.h new file mode 100644 index 00000000000000..acdcdacc7e3cd8 --- /dev/null +++ b/paddle/ap/include/graph/node_arena.h @@ -0,0 +1,179 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::graph { + +template +class NodeArena : public std::enable_shared_from_this> { + public: + NodeArena() {} + NodeArena(const NodeArena&) = delete; + NodeArena(NodeArena&&) = delete; + + adt::Result At(const tNodeId& node_id) const { + if (node_id.value() >= nodes_.size()) { + return adt::errors::IndexError{"node_id out of ranges."}; + } + return nodes_.at(node_id.value()); + } + + template + const T& New(const ConstructorT& Constructor) { + tNodeId node_id{nodes_.size()}; + const auto& node = Constructor(Node{node_id, this->shared_from_this()}); + return EmplaceBackNode(node); + } + + template + adt::Result TryNew(const ConstructorT& Constructor) { + tNodeId node_id{nodes_.size()}; + ADT_LET_CONST_REF(node, + Constructor(Node{node_id, this->shared_from_this()})); + return EmplaceBackNode(node); + } + + adt::Result> DownstreamNodes4SrcNodeId( + const tNodeId& src_id) { + if (src_id.value() >= src_node_id2downstream_nodes_.size()) { + return adt::errors::IndexError{"src node_id out of ranges."}; + } + return src_node_id2downstream_nodes_.at(src_id.value()); + } + + adt::Result> UpstreamNodes4DstNodeId( + const tNodeId& dst_id) { + if (dst_id.value() >= dst_node_id2upstream_nodes_.size()) { + return adt::errors::IndexError{"dst node_id out of 
ranges."}; + } + return dst_node_id2upstream_nodes_.at(dst_id.value()); + } + + adt::Result Connect( + const Node& src_node, + const ValidListTag& src_downstream_type, + const Node& dst_node, + const ValidListTag& dst_unstream_type) { + const auto& src_id = src_node.node_id(); + if (src_node.node_arena().lock() != this->shared_from_this()) { + return adt::errors::RuntimeError{ + "Connection between nodes from different arena is not supported. "}; + } + if (src_id.value() >= this->src_node_id2downstream_nodes_.size()) { + return adt::errors::IndexError{ + "src_id.value() is out of range " + "this->src_node_id2downstream_nodes_."}; + } + const auto& dst_id = dst_node.node_id(); + if (dst_node.node_arena().lock() != this->shared_from_this()) { + return adt::errors::RuntimeError{ + "Connection between nodes from different arena is not supported. "}; + } + if (dst_id.value() >= this->dst_node_id2upstream_nodes_.size()) { + return adt::errors::IndexError{ + "src_id.value() is out of range this->dst_node_id2upstream_nodes_."}; + } + ADT_LET_CONST_REF( + downstream_nodes_data, + GetNodeListData(&src_node_id2downstream_nodes_[src_id.value()], + src_downstream_type)); + downstream_nodes_data->emplace_back( + Node{dst_id, this->shared_from_this()}); + ADT_LET_CONST_REF( + upstream_nodes_data, + GetNodeListData(&dst_node_id2upstream_nodes_[dst_id.value()], + dst_unstream_type)); + upstream_nodes_data->emplace_back( + Node{src_id, this->shared_from_this()}); + return adt::Ok{}; + } + + const std::vector& nodes() const { return nodes_; } + + private: + adt::Result>> GetNodeListData( + NodeList* node_list, const ValidListTag& type) { + using RetDataT = adt::List>; + using RetT = adt::Result; + if (node_list->template Has>()) { + return type.Match( + [&](const IndexedTag&) -> RetT { + IndexedTag data{RetDataT{}}; + *node_list = data; + return data.data; + }, + [&](const UnindexedTag&) -> RetT { + UnindexedTag data{RetDataT{}}; + *node_list = data; + return data.data; + }); + } + 
const auto& pattern_match = ::common::Overloaded{ + [&](const IndexedTag& l, + const IndexedTag&) -> RetT { return l.data; }, + [&](const UnindexedTag& l, + const UnindexedTag&) -> RetT { return l.data; }, + [&](const auto&, const auto&) -> RetT { + return adt::errors::TypeError{"ap graph node list type mismatch."}; + }}; + return std::visit(pattern_match, node_list->variant(), type.variant()); + } + + const T& EmplaceBackNode(const T& node) { + nodes_.emplace_back(node); + src_node_id2downstream_nodes_.resize(nodes_.size()); + dst_node_id2upstream_nodes_.resize(nodes_.size()); + return nodes_.at(nodes_.size() - 1); + } + + std::vector nodes_; + std::vector> src_node_id2downstream_nodes_; + std::vector> dst_node_id2upstream_nodes_; +}; + +template +adt::Result Node::Get() const { + ADT_LET_CONST_REF(arena, adt::WeakPtrLock(this->node_arena())); + return arena->At(this->node_id()); +} + +template +adt::Result> Node::DownstreamNodes() const { + ADT_LET_CONST_REF(arena, adt::WeakPtrLock(this->node_arena())); + return arena->DownstreamNodes4SrcNodeId(this->node_id()); +} + +template +adt::Result> Node::UpstreamNodes() const { + ADT_LET_CONST_REF(arena, adt::WeakPtrLock(this->node_arena())); + return arena->UpstreamNodes4DstNodeId(this->node_id()); +} + +template +adt::Result Node::ConnectTo( + const Node& dst_node, + const ValidListTag& src_downstream_type, + const ValidListTag& dst_unstream_type) const { + ADT_LET_CONST_REF(arena, adt::WeakPtrLock(this->node_arena())); + return arena->Connect( + *this, src_downstream_type, dst_node, dst_unstream_type); +} + +} // namespace ap::graph diff --git a/paddle/ap/include/graph/node_descriptor.h b/paddle/ap/include/graph/node_descriptor.h new file mode 100644 index 00000000000000..824144679b7f6e --- /dev/null +++ b/paddle/ap/include/graph/node_descriptor.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/ap/include/graph/adt.h" +#include "paddle/ap/include/graph/node.h" + +namespace ap::graph { + +template +struct NodeDescriptor; + +template +struct NodeDescriptorInterface { + std::string DebugId(const NodeT&); + + template + adt::Result AttrsSatisfyIfBothAreOpsOrValues( + const NodeT& node, const DrrGraphNodeT& drr_node); +}; + +} // namespace ap::graph diff --git a/paddle/ap/include/graph/node_list.h b/paddle/ap/include/graph/node_list.h new file mode 100644 index 00000000000000..56a0553daa4370 --- /dev/null +++ b/paddle/ap/include/graph/node_list.h @@ -0,0 +1,113 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/tags.h" + +namespace ap::graph { + +template +struct Node; + +template +struct UndefinedTag : public std::monostate { + using std::monostate::monostate; +}; + +template +struct IndexedTag { + T data; +}; + +template +struct UnindexedTag { + T data; +}; + +template +using ValidListTagImpl = std::variant, UnindexedTag>; + +template +struct ValidListTag : public ValidListTagImpl { + using ValidListTagImpl::ValidListTagImpl; + ADT_DEFINE_VARIANT_METHODS(ValidListTagImpl); +}; + +template +using ListTagImpl = + std::variant, IndexedTag, UnindexedTag>; + +template +struct ListTag : public ListTagImpl { + using ListTagImpl::ListTagImpl; + ADT_DEFINE_VARIANT_METHODS(ListTagImpl); +}; + +template +struct NodeList : public ListTag>> { + using list_type = adt::List>; + + using ListTag::ListTag; + + ListTag type() const { + return this->Match( + [](const UndefinedTag&) -> ListTag { + return UndefinedTag{}; + }, + [](const IndexedTag&) -> ListTag { + return IndexedTag{}; + }, + [](const UnindexedTag&) -> ListTag { + return UnindexedTag{}; + }); + } + + adt::Result> Sole() const { + return this->Match( + [](const UndefinedTag&) -> adt::Result> { + return adt::errors::TypeError{"UndefinedList has no sole data"}; + }, + [](const auto& l) -> adt::Result> { + ADT_CHECK(l.data->size(), 1); + return l.data->at(0); + }); + } + + std::size_t size() const { + return this->Match( + [](const UndefinedTag&) -> std::size_t { return 0; }, + [](const auto& l) -> std::size_t { return l.data->size(); }); + } + + template + adt::Result VisitNodes(const DoEachT& DoEach) const { + return this->Match( + [](const UndefinedTag&) -> adt::Result { + return adt::Ok{}; + }, + [&](const auto& l) -> adt::Result { + for (const auto& data : *l.data) { + ADT_RETURN_IF_ERR(DoEach(data)); + } + return adt::Ok{}; + }); + } +}; + +} // namespace ap::graph diff --git 
a/paddle/ap/include/graph/node_topo_cstr.h b/paddle/ap/include/graph/node_topo_cstr.h new file mode 100644 index 00000000000000..a2c2761bbded0e --- /dev/null +++ b/paddle/ap/include/graph/node_topo_cstr.h @@ -0,0 +1,180 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/graph/adt.h" + +namespace ap::graph { + +struct NativeIrValueTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +struct NativeIrOpTopoCstr { + std::string op_name; + + bool operator==(const NativeIrOpTopoCstr& other) const { + return this->op_name == other.op_name; + } +}; + +struct NativeIrOpOperandTopoCstr { + std::size_t index; + + bool operator==(const NativeIrOpOperandTopoCstr& other) const { + return this->index == other.index; + } +}; + +struct NativeIrOpResultTopoCstr { + std::size_t index; + + bool operator==(const NativeIrOpResultTopoCstr& other) const { + return this->index == other.index; + } +}; + +struct PackedIrValueTopoCstr { + bool operator==(const PackedIrValueTopoCstr&) const { return false; } + bool operator!=(const PackedIrValueTopoCstr&) const { return false; } +}; + +struct PackedIrOpTopoCstr { + std::string op_name; + + bool operator==(const PackedIrOpTopoCstr& other) const { + return this->op_name == other.op_name; + } +}; + +struct PackedIrOpOperandTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +struct 
PackedIrOpResultTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +struct OptPackedIrOpTopoCstr { + PackedIrOpTopoCstr packed_ir_op_topo_cstr; + + bool operator==(const OptPackedIrOpTopoCstr& other) const { + return this->packed_ir_op_topo_cstr == other.packed_ir_op_topo_cstr; + } +}; + +struct OptPackedIrOpOperandTopoCstr { + PackedIrOpOperandTopoCstr packed_ir_op_operand_topo_cstr; + + bool operator==(const OptPackedIrOpOperandTopoCstr& other) const { + return this->packed_ir_op_operand_topo_cstr == + other.packed_ir_op_operand_topo_cstr; + } +}; + +struct OptPackedIrOpResultTopoCstr : public std::monostate { + PackedIrOpResultTopoCstr packed_ir_op_result_topo_cstr; + + bool operator==(const OptPackedIrOpResultTopoCstr& other) const { + return this->packed_ir_op_result_topo_cstr == + other.packed_ir_op_result_topo_cstr; + } +}; + +struct RefIrValueTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +struct RefIrOpTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +struct RefIrOpOperandTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +struct RefIrOpResultTopoCstr : public std::monostate { + using std::monostate::monostate; +}; + +using NodeTopoCstrImpl = std::variant; +// node constraint +struct NodeTopoCstr : public NodeTopoCstrImpl { + using NodeTopoCstrImpl::NodeTopoCstrImpl; + ADT_DEFINE_VARIANT_METHODS(NodeTopoCstrImpl); + + adt::Result TopoSatisfy(const NodeTopoCstr& sg_node_topo_cstr) const { + using RetT = adt::Result; + const auto& pattern_match = ::common::Overloaded{ + [&](const PackedIrOpTopoCstr& bg_topo_cstr, + const OptPackedIrOpTopoCstr& sg_topo_cstr) -> RetT { + return bg_topo_cstr == sg_topo_cstr.packed_ir_op_topo_cstr; + }, + [&](const PackedIrOpOperandTopoCstr& bg_topo_cstr, + const OptPackedIrOpOperandTopoCstr& sg_topo_cstr) -> RetT { + return bg_topo_cstr == sg_topo_cstr.packed_ir_op_operand_topo_cstr; + }, + [&](const 
PackedIrOpResultTopoCstr& bg_topo_cstr, + const OptPackedIrOpResultTopoCstr& sg_topo_cstr) -> RetT { + return bg_topo_cstr == sg_topo_cstr.packed_ir_op_result_topo_cstr; + }, + [&](const RefIrValueTopoCstr& bg_topo_cstr, + const NativeIrValueTopoCstr& sg_topo_cstr) -> RetT { return true; }, + [&](const RefIrOpTopoCstr& bg_topo_cstr, + const OptPackedIrOpTopoCstr& sg_topo_cstr) -> RetT { return true; }, + [&](const RefIrOpOperandTopoCstr& bg_topo_cstr, + const OptPackedIrOpOperandTopoCstr& sg_topo_cstr) -> RetT { + return true; + }, + [&](const RefIrOpResultTopoCstr& bg_topo_cstr, + const OptPackedIrOpResultTopoCstr& sg_topo_cstr) -> RetT { + return true; + }, + [&](const auto&, const auto&) -> RetT { + return *this == sg_node_topo_cstr; + }}; + return std::visit( + pattern_match, this->variant(), sg_node_topo_cstr.variant()); + } +}; + +struct SmallGraphNodeTopoCstr { + NodeTopoCstr node_topo_cstr; +}; + +struct BigGraphNodeTopoCstr { + NodeTopoCstr node_topo_cstr; + + adt::Result TopoSatisfy( + const SmallGraphNodeTopoCstr& sg_node_topo_cstr) const { + return this->node_topo_cstr.TopoSatisfy(sg_node_topo_cstr.node_topo_cstr); + } +}; + +} // namespace ap::graph diff --git a/paddle/ap/include/graph/tags.h b/paddle/ap/include/graph/tags.h new file mode 100644 index 00000000000000..d5cca7b58782a1 --- /dev/null +++ b/paddle/ap/include/graph/tags.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::graph { + +// graph node +ADT_DEFINE_TAG(tNodeId); + +// dst node input index + +ADT_DEFINE_TAG(tDstInIdx); + +} // namespace ap::graph diff --git a/paddle/ap/include/index_expr/builtin_frame_util.h b/paddle/ap/include/index_expr/builtin_frame_util.h new file mode 100644 index 00000000000000..f4cb67dbb0a989 --- /dev/null +++ b/paddle/ap/include/index_expr/builtin_frame_util.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/index_expr/value_method_class.h" + +namespace ap::index_expr { + +template +void VisitEachBuiltinFrameClass(const DoEachT& DoEach) { + DoEach(axpr::GetDimExprClass()); + DoEach(GetSliceClass()); + DoEach(GetIndexExprClass()); + DoEach(GetInIndexTupleExprSignatureClass()); + DoEach(GetOutIndexTupleExprSignatureClass()); + DoEach(GetOpIndexTupleExprSignatureClass()); +} + +template +ap::axpr::AttrMap MakeBuiltinFrameAttrMap() { + ap::axpr::AttrMap attr_map; + ap::axpr::VisitEachBuiltinFrameAttr( + [&](const std::string& k, const ValueT& v) { attr_map->Set(k, v); }); + VisitEachBuiltinFrameClass( + [&](const auto& cls) { attr_map->Set(cls.Name(), cls); }); + return attr_map; +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h b/paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h new file mode 100644 index 00000000000000..67fd61ded974e4 --- /dev/null +++ b/paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h @@ -0,0 +1,163 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/common/unique_id.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::index_expr { + +class DimExprCudaCodeGenerator { + public: + using ArgName4DimExprT = + std::function(const symbol::DimExpr&)>; + explicit DimExprCudaCodeGenerator(std::ostringstream* ss, + const ArgName4DimExprT& ArgName4DimExprVal, + const std::string& index_type_name) + : ss_(ss), + ArgName4DimExpr(ArgName4DimExprVal), + index_type_name_(index_type_name) {} + + std::ostringstream& ss() { return *ss_; } + + adt::Result CodeGen(const symbol::DimExpr& dim_expr) { + if (const auto& arg_name = ArgName4DimExpr(dim_expr)) { + return arg_name.value(); + } + return dim_expr.Match([&](const auto& impl) { return CodeGenImpl(impl); }); + } + + private: + adt::Result CodeGenImpl(int64_t c) { return std::to_string(c); } + + adt::Result CodeGenImpl(const std::string& var) { + return adt::errors::TypeError{ + std::string() + "no kernel argument bound to DimExpr '" + var + "'"}; + } + + adt::Result CodeGenImpl( + const symbol::Negative& dim_expr) { + const auto& [operand] = *dim_expr; + ADT_LET_CONST_REF(operand_str, CodeGen(operand)); + return std::string() + "(-" + operand_str + ")"; + } + + adt::Result CodeGenImpl( + const symbol::Reciprocal&) { + return adt::errors::ValueError{ + "reciprocal value should be processed in '*'"}; + } + + adt::Result CodeGenImpl( + const symbol::Add& dim_expr) { + ADT_CHECK(dim_expr.operands->size() > 0); + ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); + std::string ret = first; + for (int i = 1; i < dim_expr.operands->size(); ++i) { + const auto& operand = dim_expr.operands->at(i); + if (operand.Has>()) { + const auto& [negtaive_operand] = + *operand.Get>(); + ADT_LET_CONST_REF(operand_str, CodeGen(negtaive_operand)); + ret += " - " + operand_str; + } else { + ADT_LET_CONST_REF(operand_str, CodeGen(operand)); + ret += " + " + operand_str; + } + } + return std::string() + "(" + 
ret + ")"; + } + + adt::Result CodeGenImpl( + const symbol::Mul& dim_expr) { + ADT_CHECK(dim_expr.operands->size() > 0); + ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); + std::string ret = first; + for (int i = 1; i < dim_expr.operands->size(); ++i) { + const auto& operand = dim_expr.operands->at(i); + if (operand.Has>()) { + const auto& [negtaive_operand] = + *operand.Get>(); + ADT_LET_CONST_REF(operand_str, CodeGen(negtaive_operand)); + ret += " / " + operand_str; + } else { + ADT_LET_CONST_REF(operand_str, CodeGen(operand)); + ret += " * " + operand_str; + } + } + return std::string() + "(" + ret + ")"; + } + + adt::Result CodeGenImpl( + const symbol::Max& dim_expr) { + ADT_CHECK(dim_expr.operands->size() > 0); + ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); + const std::string& var_name = ap::common::NewUniqueId("_ap_sym"); + ss() << index_type_name_ << " " << var_name << " = " << first << ";\n"; + for (int i = 1; i < dim_expr.operands->size(); ++i) { + const auto& operand = dim_expr.operands->at(i); + const std::string& operand_var_name = ap::common::NewUniqueId("_ap_sym"); + ADT_LET_CONST_REF(operand_str, CodeGen(operand)); + ss() << index_type_name_ << " " << operand_var_name << " = " + << operand_str << ";\n"; + ss() << var_name << " = (" << operand_var_name << " > " << var_name + << " ? 
" << operand_var_name << " : " << var_name << ");\n"; + } + return var_name; + } + + adt::Result CodeGenImpl( + const symbol::Min& dim_expr) { + ADT_CHECK(dim_expr.operands->size() > 0); + ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); + const std::string& var_name = ap::common::NewUniqueId("_ap_sym"); + ss() << index_type_name_ << " " << var_name << " = " << first << ";\n"; + for (int i = 1; i < dim_expr.operands->size(); ++i) { + const auto& operand = dim_expr.operands->at(i); + const std::string& operand_var_name = ap::common::NewUniqueId("_ap_sym"); + ADT_LET_CONST_REF(operand_str, CodeGen(operand)); + ss() << index_type_name_ << " " << operand_var_name << " = " + << operand_str << ";\n"; + ss() << var_name << " = (" << operand_var_name << " < " << var_name + << " ? " << operand_var_name << " : " << var_name << ");\n"; + } + return var_name; + } + + adt::Result CodeGenImpl( + const symbol::Broadcast& dim_expr) { + ADT_CHECK(dim_expr.operands->size() > 0); + ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); + const std::string& var_name = ap::common::NewUniqueId("_ap_sym"); + ss() << index_type_name_ << " " << var_name << " = " << first << ";\n"; + for (int i = 1; i < dim_expr.operands->size(); ++i) { + const auto& operand = dim_expr.operands->at(i); + const std::string& operand_var_name = ap::common::NewUniqueId("_ap_sym"); + ADT_LET_CONST_REF(operand_str, CodeGen(operand)); + ss() << index_type_name_ << " " << operand_var_name << " = " + << operand_str << ";\n"; + ss() << var_name << " = (" << operand_var_name << " > " << var_name + << " ? 
" << operand_var_name << " : " << var_name << ");\n"; + } + return var_name; + } + + std::ostringstream* ss_; + ArgName4DimExprT ArgName4DimExpr; + std::string index_type_name_; +}; + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_closure.h b/paddle/ap/include/index_expr/index_closure.h new file mode 100644 index 00000000000000..4d1c133b494880 --- /dev/null +++ b/paddle/ap/include/index_expr/index_closure.h @@ -0,0 +1,103 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_expr_interpreter.h" +#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" +#include "paddle/ap/include/index_expr/value.h" +#include "paddle/ap/include/index_expr/value_method_class.h" + +namespace ap::index_expr { + +using axpr::CoreExpr; +using axpr::Lambda; + +struct IndexClosureData { + const ap::index_expr::Val ctx; + const adt::List inputs_meta; + const adt::List outputs_meta; + const adt::List in_vars; + + bool operator==(const IndexClosureData& other) const { + return other.ctx == this->ctx && other.inputs_meta == this->inputs_meta && + other.outputs_meta == this->outputs_meta && + other.in_vars == this->in_vars; + } +}; + +using Nice2IndexLambdas = + std::map>>; + +struct OrderedOneofIndexClosureImpl { + std::shared_ptr interpreter; + IndexClosureData closure_data; + Nice2IndexLambdas nice2index_lambdas; + + adt::Result operator()( + const IndexTupleExpr&) const; + + bool operator==(const OrderedOneofIndexClosureImpl& other) const { + return other.interpreter == this->interpreter && + other.closure_data == this->closure_data && + other.nice2index_lambdas == this->nice2index_lambdas; + } + + private: + adt::Result CallLambda( + const Lambda& lambda, const IndexTupleExpr&) const; +}; +ADT_DEFINE_RC(OrderedOneofIndexClosure, OrderedOneofIndexClosureImpl); + +using TrackedIndexesTransformImpl = + std::variant; + +struct TrackedIndexesTransform : public TrackedIndexesTransformImpl { + using TrackedIndexesTransformImpl::TrackedIndexesTransformImpl; + ADT_DEFINE_VARIANT_METHODS(TrackedIndexesTransformImpl); +}; + +using OpIndexesTransformSignature = + ap::index_expr::OpSignature; + +struct RecordableIndexClosureImpl { + OpIndexesTransformSignature op_indexes_transform_signature; + + adt::Result operator()( + const 
IndexTupleExpr&) const; + + bool operator==(const RecordableIndexClosureImpl& other) const { + return other.op_indexes_transform_signature == + this->op_indexes_transform_signature; + } +}; +ADT_DEFINE_RC(RecordableIndexClosure, RecordableIndexClosureImpl); + +using IndexClosureImpl = + std::variant; + +struct IndexClosure : public IndexClosureImpl { + using IndexClosureImpl::IndexClosureImpl; + ADT_DEFINE_VARIANT_METHODS(IndexClosureImpl); + + adt::Result operator()( + const IndexTupleExpr& indexes_expr) const; +}; + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr.h b/paddle/ap/include/index_expr/index_expr.h new file mode 100644 index 00000000000000..6f1d0c9f0e3d8b --- /dev/null +++ b/paddle/ap/include/index_expr/index_expr.h @@ -0,0 +1,171 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/index_expr/slice.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::index_expr { + +struct IndexTupleExpr; + +std::string IndexTupleExprToString(const std::shared_ptr&); + +struct UndefinedIndexExprImpl : public std::monostate { + using std::monostate::monostate; + + std::string ToString() const { return "IndexExpr.Undefined"; } +}; + +ADT_DEFINE_RC(UndefinedIndexExpr, UndefinedIndexExprImpl); + +struct PtrGetItemImpl { + std::string ptr_var_name; + std::shared_ptr indexes_expr; + symbol::DimExpr range; + + bool operator==(const PtrGetItemImpl& other) const { + return (other.ptr_var_name == this->ptr_var_name) && + other.indexes_expr == this->indexes_expr && + other.range == this->range; + } + + std::string ToString() const { + return std::string() + "IndexExpr.PtrGetItem(ptr_var_name=" + ptr_var_name + + ", indexes_expr=" + IndexTupleExprToString(indexes_expr) + + ", range=" + symbol::ToString(range) + ")"; + } +}; + +ADT_DEFINE_RC(PtrGetItem, PtrGetItemImpl); + +struct IndexExprDomainImpl { + symbol::DimExpr range; + + bool operator==(const IndexExprDomainImpl& other) const { + return other.range == this->range; + } + + std::string ToString() const { + return std::string() + "IndexExpr.Domain(" + symbol::ToString(range) + ")"; + } +}; + +ADT_DEFINE_RC(IndexExprDomain, const IndexExprDomainImpl); + +template +struct IndexExprBroadcastMaskImpl { + symbol::DimExpr dim; + Expr index_expr; + + bool operator==(const IndexExprBroadcastMaskImpl& other) const { + return other.dim == this->dim && other.index_expr == this->index_expr; + } + + std::string ToString() const { + return std::string() + + "IndexExpr.BroadcastMask(dim=" + symbol::ToString(dim) + + ", index_expr=" + index_expr.ToString() + ")"; + } +}; + +template 
+ADT_DEFINE_RC(IndexExprBroadcastMask, const IndexExprBroadcastMaskImpl); + +// IndexExprSlice * IndexExprAffine == IdentityFunc if fields are same. +template +struct IndexExprSliceImpl { + index_expr::Slice slice; + symbol::DimExpr range; + Expr index_expr; + + bool operator==(const IndexExprSliceImpl& other) const { + return (other.slice == this->slice) && (other.range == this->range) && + (other.index_expr == this->index_expr); + } + + std::string ToString() const { + return index_expr.ToString() + ".slice(" + slice->ToString() + + ", range=" + symbol::ToString(range) + ")"; + } +}; + +template +ADT_DEFINE_RC(IndexExprSlice, const IndexExprSliceImpl); + +template +struct IndexExprAffineImpl { + index_expr::Slice slice; + symbol::DimExpr range; + Expr index_expr; + + bool operator==(const IndexExprAffineImpl& other) const { + return (other.slice == this->slice) && (other.range == this->range) && + (other.index_expr == this->index_expr); + } + + std::string ToString() const { + return index_expr.ToString() + ".affine(" + slice->ToString() + + ", range=" + symbol::ToString(range) + ")"; + } +}; + +template +ADT_DEFINE_RC(IndexExprAffine, const IndexExprAffineImpl); + +template +struct DisjointUnionImpl { + T lhs; + T rhs; + + bool operator==(const DisjointUnionImpl& other) const { + return (other.lhs == this->lhs) && (other.rhs == this->rhs); + } + + std::string ToString() const { + return std::string() + "IndexExpr.DisjointUnion(" + lhs.ToString() + ", " + + rhs.ToString() + ")"; + } +}; + +template +ADT_DEFINE_RC(DisjointUnion, const DisjointUnionImpl); + +template +using IndexExprBase = std::variant, + IndexExprSlice, + IndexExprAffine, + DisjointUnion>; + +struct IndexExpr : public IndexExprBase { + using IndexExprBase::IndexExprBase; + ADT_DEFINE_VARIANT_METHODS(IndexExprBase); + + std::string ToString() const { + return Match([](const auto& impl) { return impl->ToString(); }); + } +}; + +template +axpr::TypeImpl> GetIndexExprClass(); + +} // namespace 
ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_builtin_functions.h b/paddle/ap/include/index_expr/index_expr_builtin_functions.h new file mode 100644 index 00000000000000..ac3dc0db70edc0 --- /dev/null +++ b/paddle/ap/include/index_expr/index_expr_builtin_functions.h @@ -0,0 +1,441 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_functions.h" +#include "paddle/ap/include/index_expr/index_expr_util.h" +#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" +#include "paddle/ap/include/index_expr/value.h" + +namespace ap::index_expr { + +using adt::Maybe; +using adt::Result; + +template +Result MakePtrGetItem(const Val&, const std::vector& args); + +template +Result MakeIndexExprBroadcastMask(const Val&, + const std::vector& args); + +template +Result MakeSlice(const Val&, const std::vector& args); + +template +Result MakeIndexExprSlice(const Val&, const std::vector& args); + +template +Result MakeIndexExprAffine(const Val&, const std::vector& args); + +template +Result MakeDisjointUnion(const Val&, const std::vector& args); + +template +Result MakeIndexTupleExprPermute(const Val&, const std::vector& args); + +template +Result MakeIndexTupleExprReshape(const Val&, const std::vector& args); + +template +Result MakeIndexTupleExprTransform(axpr::InterpreterBase* 
interpreter, + const Val& obj, + const std::vector& args); + +template +Result MakeOpIndexTupleExprSignature(const Val&, + const std::vector& args); + +template +Result MakeInIndexTupleExprSignature(const Val&, + const std::vector& args); + +template +Result MakeOutIndexTupleExprSignature(const Val&, + const std::vector& args); + +template +inline Maybe TryGetImplIndexExprValue(const Val& val) { + const auto& ret = val.template TryGet(); + if (ret.template HasOkValue()) { + return ret.GetOkValue(); + } + return adt::Nothing{}; +} + +template +inline adt::Result TryGetDimExpr(const Val& val) { + using RetT = adt::Result; + return val.Match([](int64_t c) -> RetT { return symbol::DimExpr{c}; }, + [](const axpr::BuiltinClassInstance& instance) -> RetT { + return instance.template TryGet(); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + "TryGetDimExpr() failed. argument 1 should an int or " + "DimExpr (not " + + axpr::GetTypeName(val) + ")"}; + }); +} + +template +inline Maybe TryGetInt64(const Val& val) { + return val.Match( + [](int64_t c) -> Maybe { return c; }, + [](const symbol::DimExpr& dim_expr) -> Maybe { + return dim_expr.Match( + [](const int64_t c) -> Maybe { return c; }, + [](const auto&) -> Maybe { return adt::Nothing{}; }); + }, + [&](const auto&) -> Maybe { return adt::Nothing{}; }); +} + +template +Result MakePtrGetItem(const Val&, const std::vector& args) { + if (args.size() != 3) { + return adt::errors::TypeError{ + std::string("PtrGetItem takes 3 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + const auto& opt_arg1 = TryGetImplIndexExprValue(args.at(1)); + ADT_LET_CONST_REF(dim_expr, TryGetDimExpr(args.at(2))); + return std::visit(::common::Overloaded{ + [&](const std::string& ptr_var_name, + const IndexTupleExpr& indexes_expr) -> Result { + return PtrGetItem{ + ptr_var_name, + std::make_shared(indexes_expr), + dim_expr}; + }, + [&](const auto&, const auto&) -> Result { + return 
adt::errors::InvalidArgumentError{ + "wrong argument type for PtrGetItem"}; + }}, + args.at(0).variant(), + opt_arg1.variant()); +} + +namespace detail { + +template +Result ConvertResult(const T& result) { + return result.Match([](const auto& impl) -> Result { return impl; }); +} + +} // namespace detail + +template +Result MakeIndexExprBroadcastMask(const Val&, + const std::vector& args) { + if (args.size() != 2) { + return adt::errors::TypeError{ + std::string("IndexExprBroadcastMask takes 2 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + ADT_LET_CONST_REF(dim_expr, TryGetDimExpr(args.at(0))); + const auto& opt_arg1 = TryGetImplIndexExprValue(args.at(1)); + ValidIndexExprBuilder builder{}; + const auto& pattern_match = ::common::Overloaded{ + [&](const IndexExpr& index_expr) -> Result { + return detail::ConvertResult( + builder.BroadcastMask(dim_expr, index_expr)); + }, + [&](const auto&) -> Result { + return adt::errors::InvalidArgumentError{ + "wrong argument type for IndexExprBroadcastMask"}; + }}; + return std::visit(pattern_match, opt_arg1.variant()); +} + +template +Result MakeSlice(const Val&, const std::vector& args) { + if (args.size() != 3) { + return adt::errors::TypeError{std::string("Slice takes 3 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + ADT_LET_CONST_REF(start, TryGetDimExpr(args.at(0))); + ADT_LET_CONST_REF(stop, TryGetDimExpr(args.at(1))); + ADT_LET_CONST_REF(step, TryGetDimExpr(args.at(1))); + return Val{Slice{start, stop, step}}; +} + +template +Result MakeIndexExprSlice(const Val&, const std::vector& args) { + if (args.size() != 3) { + return adt::errors::TypeError{ + std::string("IndexExprSlice takes 3 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + const auto& opt_slice = TryGetImplIndexExprValue(args.at(0)); + ADT_LET_CONST_REF(range, TryGetDimExpr(args.at(1))); + const auto& opt_index_expr = TryGetImplIndexExprValue(args.at(2)); + const auto& pattern_match = 
::common::Overloaded{ + [](const Slice& slice, const IndexExpr& expr) -> Result { + ValidIndexExprBuilder builder{}; + return detail::ConvertResult(builder.Slice(slice, range, expr)); + }, + [](const auto&, const auto&) -> Result { + return adt::errors::InvalidArgumentError{ + "wrong argument type for IndexExprSlice"}; + }}; + return std::visit( + pattern_match, opt_slice.variant(), opt_index_expr.variant()); +} + +template +Result MakeIndexExprAffine(const Val&, const std::vector& args) { + if (args.size() != 3) { + return adt::errors::TypeError{ + std::string("IndexExprAffine takes 3 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + const auto& opt_slice = TryGetImplIndexExprValue(args.at(0)); + ADT_LET_CONST_REF(range, TryGetDimExpr(args.at(1))); + const auto& opt_index_expr = TryGetImplIndexExprValue(args.at(2)); + return std::visit( + ::common::Overloaded{ + [](const Slice& slice, const IndexExpr& index_expr) -> Result { + ValidIndexExprBuilder builder{}; + return detail::ConvertResult( + builder.Affine(slice, range, index_expr)); + }, + [](const auto&, const auto&) -> Result { + return adt::errors::InvalidArgumentError{ + "wrong argument type for IndexExprAffine"}; + }}, + opt_slice.variant(), + opt_index_expr.variant()); +} + +template +Result MakeDisjointUnion(const Val&, const std::vector& args) { + const auto& opt_lhs = TryGetImplIndexExprValue(args.at(1)); + const auto& opt_rhs = TryGetImplIndexExprValue(args.at(1)); + return std::visit( + ::common::Overloaded{ + [](const IndexExpr& lhs, const IndexExpr& rhs) -> Result { + ValidIndexExprBuilder builder{}; + return detail::ConvertResult(builder.DisjointUnion(lhs, rhs)); + }, + [](const auto&, const auto&) -> Result { + return adt::errors::InvalidArgumentError{ + "wrong argument type for DisjointUnion"}; + }}, + opt_lhs.variant(), + opt_rhs.variant()); +} + +template +inline Maybe> TryGetInt64List(const Val& val) { + return val.Match( + [](const adt::List& l) -> Maybe> { + adt::List 
ret; + ret->reserve(l->size()); + for (const auto& elt : *l) { + const auto& opt_int = TryGetInt64(elt); + if (!opt_int.template Has()) { + return adt::Nothing{}; + } + ret->push_back(opt_int.template Get()); + } + return ret; + }, + [](const auto&) -> Maybe> { return adt::Nothing{}; }); +} + +template +inline adt::Result> TryGetDimExprList( + const Val& val) { + using RetT = adt::Result>; + ADT_LET_CONST_REF(l, val.template CastTo>()); + adt::List ret; + ret->reserve(l->size()); + for (const auto& elt : *l) { + ADT_LET_CONST_REF(int_val, TryGetDimExpr(elt)); + ret->push_back(int_val); + } + return ret; +} + +template +Result MakeIndexTupleExprPermute(const Val&, + const std::vector& args) { + if (args.size() != 2) { + return adt::errors::TypeError{ + std::string("IndexTupleExprPermute takes 2 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + const auto& opt_perms = TryGetInt64List(args.at(0)); + const auto& opt_expr = TryGetImplIndexExprValue(args.at(1)); + ValidIndexExprBuilder builder{}; + return std::visit( + ::common::Overloaded{ + [&](const adt::List& perms, + const IndexTupleExpr& expr) -> Result { + return detail::ConvertResult(builder.Permute(perms, expr)); + }, + [](const auto&, const auto&) -> Result { + return adt::errors::InvalidArgumentError{ + "wrong argument type for IndexTupleExprPermute"}; + }}, + opt_perms.variant(), + opt_expr.variant()); +} + +template +Result MakeIndexTupleExprReshape(const Val&, + const std::vector& args) { + if (args.size() != 2) { + return adt::errors::TypeError{ + std::string("IndexTupleExprReshape takes 2 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + ADT_LET_CONST_REF(shape, TryGetDimExprList(args.at(0))); + const auto& opt_expr = TryGetImplIndexExprValue(args.at(1)); + ValidIndexExprBuilder builder{}; + return std::visit( + ::common::Overloaded{ + [&](const IndexTupleExpr& expr) -> Result { + return detail::ConvertResult(builder.Reshape(shape, expr)); + }, + [](const 
auto&) -> Result { + return adt::errors::InvalidArgumentError{ + "wrong argument type for IndexTupleExprReshape"}; + }}, + opt_expr.variant()); +} + +template +Result MakeIndexTupleExprTransform(axpr::InterpreterBase* interpreter, + const Val&, + const std::vector& args) { + if (args.size() < 1) { + return adt::errors::TypeError{ + "IndexTupleExprTransform takes at least 1 argument but 0 were given."}; + } + const auto& opt_expr = TryGetImplIndexExprValue(args.at(0)); + if (!opt_expr.template Has()) { + return adt::errors::TypeError{ + "The first argument of IndexTupleExprTransform must be a " + "IndexTupleExpr."}; + } + const auto& indexes_expr = opt_expr.template Get(); + const auto& opt_rank = IndexTupleExprGetRank(indexes_expr); + if (!opt_rank.template Has()) { + return adt::errors::TypeError{ + "The first argument of IndexTupleExprTransform must be a ranked " + "IndexTupleExpr."}; + } + const auto& opt_dim_exprs = IndexTupleExprGetRanges(indexes_expr); + if (!opt_dim_exprs.template Has>()) { + return adt::errors::RuntimeError{ + "error occurred where calling IndexTupleExprGetDims"}; + } + const auto& dim_exprs = + opt_dim_exprs.template Get>(); + if (opt_rank.template Get() != args.size() - 1) { + return adt::errors::TypeError{ + "The rank of first argument must equal to number of lambdas."}; + } + adt::List transform_index_exprs; + transform_index_exprs->reserve(args.size() - 1); + for (int i = 1; i < args.size(); ++i) { + const auto& opt_closure = args.at(i).template TryGet>(); + ADT_RETURN_IF_ERR(opt_closure); + const auto& closure = opt_closure.GetOkValue(); + + if (closure->lambda->args.size() != 1) { + return adt::errors::TypeError{std::string("Argument ") + + std::to_string(i) + + " is not a single-argumented closure."}; + } + int idx = i - 1; + IndexExprDomain domain{dim_exprs->at(idx)}; + const auto& ret_lambda_call = + interpreter->InterpretCall(closure, {Val{domain}}); + ADT_RETURN_IF_ERR(ret_lambda_call); + const auto& ret_index_expr = + 
TryGetImplIndexExprValue(ret_lambda_call.GetOkValue()); + if (!ret_index_expr.template Has()) { + return adt::errors::TypeError{std::string("closure of argument") + + std::to_string(i) + + " does not return a IndexExpr."}; + } + transform_index_exprs->push_back(ret_index_expr.template Get()); + } + ValidIndexExprBuilder builder{}; + ADT_LET_CONST_REF(ret, + detail::ConvertResult(builder.Transform( + transform_index_exprs, indexes_expr))); + return ret; +} + +template +Result MakeOpIndexTupleExprSignature(const Val&, + const std::vector& args) { + if (args.size() != 2) { + return adt::errors::TypeError{ + std::string("OpIndexTupleExprSignature takes 2 arguments but ") + + std::to_string(args.size()) + "were given."}; + } + const auto& in_sig = args.at(0); + const auto& opt_in = in_sig.template TryGet(); + ADT_RETURN_IF_ERR(opt_in); + const auto& in = opt_in.GetOkValue(); + const auto& out_sig = args.at(1); + const auto& opt_out = out_sig.template TryGet(); + ADT_RETURN_IF_ERR(opt_out); + const auto& out = opt_out.GetOkValue(); + return OpIndexTupleExprSignature{in, out}; +} + +template +Result MakeInIndexTupleExprSignature(const Val&, + const std::vector& args) { + adt::List indexes_exprs; + indexes_exprs->reserve(args.size()); + for (const auto& arg : args) { + const auto& maybe_indexes_expr = + TryGetImplIndexExprValue(arg); + if (!maybe_indexes_expr.template Has()) { + return adt::errors::InvalidArgumentError{ + "only arguments of `IndexTupleExpr` type is valid for " + "InIndexTupleExprSignature"}; + } + indexes_exprs->push_back(maybe_indexes_expr.template Get()); + } + return InIndexTupleExprSignature{indexes_exprs}; +} + +template +Result MakeOutIndexTupleExprSignature(const Val&, + const std::vector& args) { + adt::List indexes_exprs; + indexes_exprs->reserve(args.size()); + for (const auto& arg : args) { + const auto& maybe_indexes_expr = + TryGetImplIndexExprValue(arg); + if (!maybe_indexes_expr.template Has()) { + return adt::errors::InvalidArgumentError{ + 
"only arguments of `IndexTupleExpr` type is valid for " + "OutIndexTupleExprSignature"}; + } + indexes_exprs->push_back(maybe_indexes_expr.template Get()); + } + return OutIndexTupleExprSignature{indexes_exprs}; +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_interpreter.h b/paddle/ap/include/index_expr/index_expr_interpreter.h new file mode 100644 index 00000000000000..f4f0d845c15801 --- /dev/null +++ b/paddle/ap/include/index_expr/index_expr_interpreter.h @@ -0,0 +1,47 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/builtin_functions.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_expr_builtin_functions.h" +#include "paddle/ap/include/index_expr/value.h" +#include "paddle/ap/include/index_expr/value_method_class.h" + +namespace ap::index_expr { + +class IndexExprInterpreter { + public: + IndexExprInterpreter(); + IndexExprInterpreter(const IndexExprInterpreter&) = delete; + IndexExprInterpreter(IndexExprInterpreter&&) = delete; + + Result operator()(const axpr::Lambda& lambda, + const std::vector& args) const { + return adt::errors::NotImplementedError{ + "IndexExprInterpreter::operator()(lambda, args)"}; + } + + Result operator()( + const std::unordered_map>& + global_functions, + const axpr::Lambda& lambda, + const std::vector& args) const { + return adt::errors::NotImplementedError{ + "IndexExprInterpreter::operator()(global_functions, lambda, args)"}; + } +}; + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_method_class.h b/paddle/ap/include/index_expr/index_expr_method_class.h new file mode 100644 index 00000000000000..ee733851f5f65e --- /dev/null +++ b/paddle/ap/include/index_expr/index_expr_method_class.h @@ -0,0 +1,46 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_expr_builtin_functions.h" + +namespace ap::index_expr { + +template +struct IndexExprMethodClass { + using This = IndexExprMethodClass; + using Self = IndexExpr; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self.ToString(); + } +}; + +template +axpr::TypeImpl> GetIndexExprClass() { + using ImplMethods = IndexExprMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "IndexExpr", + [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_util.h b/paddle/ap/include/index_expr/index_expr_util.h new file mode 100644 index 00000000000000..e054903e5b2660 --- /dev/null +++ b/paddle/ap/include/index_expr/index_expr_util.h @@ -0,0 +1,168 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/error.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" + +namespace ap::index_expr { + +using adt::Maybe; + +inline Maybe IndexTupleExprGetRank(const IndexTupleExpr& expr) { + return expr.Match( + [](const UndefinedIndexTupleExpr&) -> Maybe { + return adt::Nothing{}; + }, + [](const NothingIndexTupleExpr&) -> Maybe { + return adt::Nothing{}; + }, + [](const IntArrayLikeIndexTupleExpr&) -> Maybe { + return adt::Nothing{}; + }, + [](const IndexTupleExprDomain& domain) -> Maybe { + return domain->ranges->size(); + }, + [](const IndexTupleExprPermute& perm) -> Maybe { + return perm->perms->size(); + }, + [](const IndexTupleExprReshape& reshape) + -> Maybe { return reshape->shape->size(); }, + [](const IndexTupleExprTransform& transform) + -> Maybe { return transform->index_exprs->size(); }); +} + +inline Maybe IndexExprGetRange(const IndexExpr& index_expr) { + return index_expr.Match( + [](const UndefinedIndexExpr&) -> Maybe { + return adt::Nothing{}; + }, + [](const PtrGetItem& ptr_get_item) -> Maybe { + return ptr_get_item->range; + }, + [](const IndexExprDomain& domain) -> Maybe { + return domain->range; + }, + [](const IndexExprBroadcastMask& mask) + -> Maybe { return mask->dim; }, + [](const IndexExprSlice& index_slice) + -> Maybe { return index_slice->range; }, + [](const IndexExprAffine& index_affine) + -> Maybe { return index_affine->range; }, + [](const DisjointUnion& union_expr) -> Maybe { + const auto& opt_lhs_dim_expr = IndexExprGetRange(union_expr->lhs); + const auto& opt_rhs_dim_expr = IndexExprGetRange(union_expr->rhs); + return std::visit( + ::common::Overloaded{ + [](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) + -> Maybe { return lhs + rhs; }, + [](const auto&, const auto&) -> Maybe { + return adt::Nothing{}; + }}, + opt_lhs_dim_expr.variant(), + opt_rhs_dim_expr.variant()); + }); +} 
+ +inline Maybe IndexExprGetDomain(const IndexExpr& index_expr) { + return index_expr.Match( + [](const UndefinedIndexExpr&) -> Maybe { + return adt::Nothing{}; + }, + [](const PtrGetItem& ptr_get_item) -> Maybe { + return ptr_get_item->range; + }, + [](const IndexExprDomain& domain) -> Maybe { + return domain->range; + }, + [](const IndexExprBroadcastMask& mask) + -> Maybe { + return IndexExprGetDomain(mask->index_expr); + }, + [](const IndexExprSlice& index_slice) + -> Maybe { + return IndexExprGetDomain(index_slice->index_expr); + }, + [](const IndexExprAffine& index_affine) + -> Maybe { + return IndexExprGetDomain(index_affine->index_expr); + }, + [](const DisjointUnion& union_expr) -> Maybe { + const auto& lhs = IndexExprGetDomain(union_expr->lhs); + const auto& rhs = IndexExprGetDomain(union_expr->rhs); + const auto& pattern_match = ::common::Overloaded{ + [&](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) { + if (lhs == rhs) { + return Maybe{lhs}; + } else { + return Maybe{adt::Nothing{}}; + } + }, + [&](const auto&, const auto&) { + return Maybe{adt::Nothing{}}; + }}; + return std::visit(pattern_match, lhs.variant(), rhs.variant()); + }); +} + +inline Maybe> IndexTupleExprGetRanges( + const IndexTupleExpr& expr) { + return expr.Match( + [](const UndefinedIndexTupleExpr&) -> Maybe> { + return adt::Nothing{}; + }, + [](const NothingIndexTupleExpr&) -> Maybe> { + return adt::Nothing{}; + }, + [](const IntArrayLikeIndexTupleExpr&) + -> Maybe> { return adt::Nothing{}; }, + [](const IndexTupleExprDomain& domain) + -> Maybe> { return domain->ranges; }, + [](const IndexTupleExprPermute& perm) + -> Maybe> { + const auto& opt_origin_dim_exprs = + IndexTupleExprGetRanges(perm->indexes_expr); + if (opt_origin_dim_exprs.Has()) { + return adt::Nothing{}; + } + const auto& origin_dim_exprs = + opt_origin_dim_exprs.Get>(); + adt::List ret; + ret->reserve(perm->perms->size()); + for (const int idx : *perm->perms) { + ret->push_back(origin_dim_exprs->at(idx)); + } + 
return ret; + }, + [](const IndexTupleExprReshape& reshape) + -> Maybe> { return reshape->shape; }, + [](const IndexTupleExprTransform& transform) + -> Maybe> { + adt::List ret; + ret->reserve(transform->index_exprs->size()); + for (const auto& index_expr : *transform->index_exprs) { + const auto& opt_dim_expr = IndexExprGetRange(index_expr); + if (opt_dim_expr.Has()) { + return adt::Nothing{}; + } + ret->push_back(opt_dim_expr.Get()); + } + return ret; + }); +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_tuple_expr.h b/paddle/ap/include/index_expr/index_tuple_expr.h new file mode 100644 index 00000000000000..09c90cd9496cf7 --- /dev/null +++ b/paddle/ap/include/index_expr/index_tuple_expr.h @@ -0,0 +1,196 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/slice.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::index_expr { + +struct UndefinedIndexTupleExprImpl : public std::monostate { + using std::monostate::monostate; + + std::string ToString() const { return "IndexTupleExpr.Undefined"; } + + const char* TypeName() const { return "UndefinedIndexTupleExpr"; } +}; +ADT_DEFINE_RC(UndefinedIndexTupleExpr, UndefinedIndexTupleExprImpl); + +struct NothingIndexTupleExprImpl : public std::monostate { + using std::monostate::monostate; + + std::string ToString() const { return "IndexTupleExpr.Nothing"; } + + const char* TypeName() const { return "NothingIndexTupleExpr"; } +}; +ADT_DEFINE_RC(NothingIndexTupleExpr, NothingIndexTupleExprImpl); + +struct IntArrayLikeIndexTupleExprImpl : public std::monostate { + using std::monostate::monostate; + + std::string ToString() const { return "IndexTupleExpr.IntArrayLike"; } + + const char* TypeName() const { return "IntArrayLikeIndexTupleExpr"; } +}; +ADT_DEFINE_RC(IntArrayLikeIndexTupleExpr, IntArrayLikeIndexTupleExprImpl); + +struct IndexTupleExprDomainImpl { + adt::List ranges; + bool operator==(const IndexTupleExprDomainImpl& other) const { + return other.ranges == this->ranges; + } + + std::string ToString() const { + std::ostringstream ss; + ss << "["; + int i = 0; + for (const auto& elt : *ranges) { + if (i++ > 0) { + ss << ", "; + } + ss << symbol::ToString(elt); + } + ss << "]"; + return std::string() + "IndexTupleExpr.Domain(" + ss.str() + ")"; + } + + const char* TypeName() const { return "IndexTupleExprDomain"; } +}; +ADT_DEFINE_RC(IndexTupleExprDomain, const IndexTupleExprDomainImpl); + +template +struct IndexTupleExprPermuteImpl { + adt::List perms; + Expr indexes_expr; + + bool 
operator==(const IndexTupleExprPermuteImpl& other) const { + return other.perms == this->perms && + other.indexes_expr == this->indexes_expr; + } + + std::string ToString() const { + std::ostringstream ss; + ss << "["; + int i = 0; + for (int64_t perm : *perms) { + if (i++ > 0) { + ss << ", "; + } + ss << perm; + } + ss << "]"; + return indexes_expr.ToString() + ".permute(" + ss.str() + ")"; + } + + const char* TypeName() const { return "IndexTupleExprPermute"; } +}; + +template +ADT_DEFINE_RC(IndexTupleExprPermute, const IndexTupleExprPermuteImpl); + +template +struct IndexTupleExprReshapeImpl { + adt::List shape; + Expr indexes_expr; + + bool operator==(const IndexTupleExprReshapeImpl& other) const { + return other.shape == this->shape && + other.indexes_expr == this->indexes_expr; + } + + std::string ToString() const { + std::ostringstream ss; + ss << "["; + int i = 0; + for (const auto& elt : *shape) { + if (i++ > 0) { + ss << ", "; + } + ss << symbol::ToString(elt); + } + ss << "]"; + return indexes_expr.ToString() + ".reshape(" + ss.str() + ")"; + } + + const char* TypeName() const { return "IndexTupleExprReshape"; } +}; +template +ADT_DEFINE_RC(IndexTupleExprReshape, const IndexTupleExprReshapeImpl); + +template +struct IndexTupleExprTransformImpl { + adt::List index_exprs; + Expr indexes_expr; + + bool operator==(const IndexTupleExprTransformImpl& other) const { + return other.index_exprs == this->index_exprs && + other.indexes_expr == this->indexes_expr; + } + + std::string ToString() const { + std::ostringstream ss; + ss << "["; + int i = 0; + for (const auto& elt : *index_exprs) { + if (i++ > 0) { + ss << ", "; + } + ss << elt.ToString(); + } + ss << "]"; + return indexes_expr.ToString() + ".transform(" + ss.str() + ")"; + } + + const char* TypeName() const { return "IndexTupleExprTransform"; } +}; +template +ADT_DEFINE_RC(IndexTupleExprTransform, const IndexTupleExprTransformImpl); + +template +using IndexTupleExprBase = std::variant, + 
IndexTupleExprReshape, + IndexTupleExprTransform>; + +struct IndexTupleExpr : public IndexTupleExprBase { + using IndexTupleExprBase::IndexTupleExprBase; + ADT_DEFINE_VARIANT_METHODS(IndexTupleExprBase); + + const char* TypeName() const { + return Match([](const auto& impl) { return impl->TypeName(); }); + } + + std::string ToString() const { + return Match([](const auto& impl) { return impl->ToString(); }); + } +}; + +inline std::string IndexTupleExprToString( + const std::shared_ptr& indexes_expr) { + return indexes_expr->ToString(); +} + +template +axpr::TypeImpl> GetIndexTupleExprClass(); + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h b/paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h new file mode 100644 index 00000000000000..a12baa6ac45066 --- /dev/null +++ b/paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h @@ -0,0 +1,97 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/common/unique_id.h" +#include "paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" + +namespace ap::index_expr { + +class IndexTupleExprCudaCodeGenerator { + public: + using ArgName4DimExprT = + std::function(const symbol::DimExpr&)>; + IndexTupleExprCudaCodeGenerator( + std::ostringstream* ss, + const std::vector& loop_var_names, + const ArgName4DimExprT& ArgName4DimExpr) + : ss_(ss), + loop_var_names_(loop_var_names), + index_type_name_("int64_t"), + dim_expr_code_gen_(ss, ArgName4DimExpr, "int64_t") {} + + std::ostringstream& ss() { return *ss_; } + + adt::Result CodeGen(const IndexTupleExpr& indexes_expr) { + return indexes_expr.Match( + [&](const IndexTupleExprDomain& domain) -> adt::Result { + return CodeGenImpl(domain); + }, + [&](const auto& impl) -> adt::Result { + return adt::errors::NotImplementedError{ + std::string() + + "IndexTupleExprCudaCodeGenerator::CodeGen not support " + + impl->TypeName() + " yet."}; + }); + } + + private: + adt::Result CodeGenImpl(const IndexTupleExprDomain& domain) { + const auto& var_name = NewTmpVarName("_ap_i"); + int i = 0; + auto DoEachPair = [&](const auto& iter, + const auto& stride) -> adt::Result { + if (i++ == 0) { + ADT_CHECK(stride == symbol::DimExpr{int64_t(1)}); + ss() << index_type_name_ << " " << var_name << " = " << iter << ";\n"; + } else { + ADT_LET_CONST_REF(stride_var_name, dim_expr_code_gen_.CodeGen(stride)); + ss() << var_name << " += " << iter << " * " << stride_var_name << ";\n"; + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachIterAndStride(domain->ranges, DoEachPair)); + return var_name; + } + + template + adt::Result VisitEachIterAndStride( + const adt::List& ranges, const DoEachPairT& DoEachPair) { + symbol::DimExpr stride{int64_t(1)}; + ADT_CHECK(loop_var_names_.size() == ranges->size()); + for (int i = 
loop_var_names_.size() - 1; i >= 0; --i) { + const auto& iter_var_name = loop_var_names_.at(i); + const auto& dim = ranges->at(i); + ADT_RETURN_IF_ERR(DoEachPair(iter_var_name, stride)); + stride = stride * dim; + } + return adt::Ok{}; + } + + std::string NewTmpVarName(const std::string& prefix) { + return ap::common::NewUniqueId(prefix); + } + + std::ostringstream* ss_; + std::vector loop_var_names_; + std::string index_type_name_; + DimExprCudaCodeGenerator dim_expr_code_gen_; +}; + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_tuple_expr_method_class.h b/paddle/ap/include/index_expr/index_tuple_expr_method_class.h new file mode 100644 index 00000000000000..e59dc88728bb31 --- /dev/null +++ b/paddle/ap/include/index_expr/index_tuple_expr_method_class.h @@ -0,0 +1,89 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/index_expr/index_expr_builtin_functions.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" + +namespace ap::index_expr { + +template +struct IndexTupleExprMethodClass { + using This = IndexTupleExprMethodClass; + using Self = IndexTupleExpr; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self.ToString(); + } +}; + +template +struct TypeImplIndexTupleExprMethodClass { + using This = TypeImplIndexTupleExprMethodClass; + using Self = axpr::TypeImpl; + + static adt::Result StaticConstructIndexTupleExprDomain( + const ValueT&, const std::vector& args) { + return This{}.ConstructIndexTupleExprDomain(args); + } + + adt::Result ConstructIndexTupleExprDomain( + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "'IndexTupleExpr.Domain' takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(list, args.at(0).template TryGet>()) + << adt::errors::TypeError{std::string() + + "the argument 1 of 'IndexTupleExpr.Domain' " + "should a list of DimExpr."}; + adt::List dim_exprs; + dim_exprs->reserve(list->size()); + for (const auto& arg : *list) { + ADT_LET_CONST_REF(dim_expr, CastToDimExpr(arg)) + << adt::errors::TypeError{std::string() + + "the argument 1 of 'IndexTupleExpr.Domain' " + "should a list of DimExpr."}; + dim_exprs->emplace_back(dim_expr); + } + IndexTupleExpr index_tuple_expr{IndexTupleExprDomain{dim_exprs}}; + axpr::BuiltinClassInstance instance{ + GetIndexTupleExprClass(), index_tuple_expr}; + return instance; + } + + adt::Result CastToDimExpr(const ValueT& val) { + ADT_LET_CONST_REF(dim_expr, TryGetDimExpr(val)); + return dim_expr; + } +}; + +template +axpr::TypeImpl> GetIndexTupleExprClass() { + using TypeImplMethods 
= TypeImplIndexTupleExprMethodClass; + using ImplMethods = IndexTupleExprMethodClass; + static auto cls( + axpr::MakeBuiltinClass("IndexTupleExpr", [&](const auto& Define) { + Define("Domain", &TypeImplMethods::StaticConstructIndexTupleExprDomain); + Define("__str__", &ImplMethods::ToString); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/op_index_tuple_expr_signature.h b/paddle/ap/include/index_expr/op_index_tuple_expr_signature.h new file mode 100644 index 00000000000000..37bc15fa948e85 --- /dev/null +++ b/paddle/ap/include/index_expr/op_index_tuple_expr_signature.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/ap/include/index_expr/op_signature.h" + +namespace ap::index_expr { + +using InIndexTupleExprSignature = InputSignature; +using OutIndexTupleExprSignature = OutputSignature; +using OpIndexTupleExprSignature = OpSignature; + +} // namespace ap::index_expr + +namespace ap::axpr { + +template +axpr::TypeImpl> +GetInIndexTupleExprSignatureClass(); + +template +axpr::TypeImpl> +GetOutIndexTupleExprSignatureClass(); + +template +axpr::TypeImpl> +GetOpIndexTupleExprSignatureClass(); + +} // namespace ap::axpr diff --git a/paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h b/paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h new file mode 100644 index 00000000000000..7459f2ecab9bc1 --- /dev/null +++ b/paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h @@ -0,0 +1,91 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" + +namespace ap::index_expr { + +template +struct InIndexTupleExprSignatureMethodClass { + using Self = index_expr::InIndexTupleExprSignature; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self->ToString(); + } +}; + +template +axpr::TypeImpl> +GetInIndexTupleExprSignatureClass() { + using ImplMethods = InIndexTupleExprSignatureMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "InIndexTupleExprSignature", + [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct OutIndexTupleExprSignatureMethodClass { + using Self = index_expr::OutIndexTupleExprSignature; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self->ToString(); + } +}; + +template +axpr::TypeImpl> +GetOutIndexTupleExprSignatureClass() { + using ImplMethods = OutIndexTupleExprSignatureMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "OutIndexTupleExprSignature", + [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct OpIndexTupleExprSignatureMethodClass { + using Self = index_expr::OpIndexTupleExprSignature; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self->ToString(); + } +}; + +template +axpr::TypeImpl> +GetOpIndexTupleExprSignatureClass() { + using ImplMethods = OpIndexTupleExprSignatureMethodClass; + static auto cls(axpr::MakeBuiltinClass( 
+ "OpIndexTupleExprSignature", [&](const auto& Define) { + Define("__str__", + &OpIndexTupleExprSignatureMethodClass::ToString); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/op_signature.h b/paddle/ap/include/index_expr/op_signature.h new file mode 100644 index 00000000000000..0c73ec7006e82c --- /dev/null +++ b/paddle/ap/include/index_expr/op_signature.h @@ -0,0 +1,78 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/index_expr/index_expr.h" + +namespace ap::index_expr { + +template +struct InputSignature { + adt::List descriptors; + + std::string ToString() const { + std::ostringstream ss; + int i = 0; + for (const auto& elt : *descriptors) { + if (i++ > 0) { + ss << ", "; + } + ss << elt.ToString(); + } + return std::string() + "InputSignature(" + ss.str() + ")"; + } + + bool operator==(const InputSignature& other) const { + return other.descriptors == this->descriptors; + } +}; + +template +struct OutputSignature { + adt::List descriptors; + + std::string ToString() const { + std::ostringstream ss; + int i = 0; + for (const auto& elt : *descriptors) { + if (i++ > 0) { + ss << ", "; + } + ss << elt.ToString(); + } + return std::string() + "OutputSignature(" + ss.str() + ")"; + } + + bool operator==(const OutputSignature& other) const { + return other.descriptors == this->descriptors; + } +}; + +template +struct OpSignature { + InputSignature in_signature; + OutputSignature out_signature; + + std::string ToString() const { + return std::string() + "OpSignature(" + in_signature.ToString() + ", " + + out_signature.ToString() + ")"; + } + + bool operator==(const OpSignature& other) const { + return other.in_signature == this->in_signature && + other.out_signature == this->out_signature; + } +}; + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/slice.h b/paddle/ap/include/index_expr/slice.h new file mode 100644 index 00000000000000..4e4af6456efbcc --- /dev/null +++ b/paddle/ap/include/index_expr/slice.h @@ -0,0 +1,47 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" + +namespace ap::index_expr { + +struct SliceImpl { + symbol::DimExpr start; + symbol::DimExpr stop; + symbol::DimExpr step; + + bool operator==(const SliceImpl& other) const { + return (other.start == this->start) && (other.stop == this->stop) && + (other.step == this->step); + } + + std::string ToString() const { + return std::string() + "Slice(start=" + symbol::ToString(start) + + ", stop=" + symbol::ToString(stop) + + ", step=" + symbol::ToString(step) + ")"; + } +}; + +ADT_DEFINE_RC(Slice, const SliceImpl); + +template +axpr::TypeImpl> GetSliceClass(); + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/slice_method_class.h b/paddle/ap/include/index_expr/slice_method_class.h new file mode 100644 index 00000000000000..5398457afcbba7 --- /dev/null +++ b/paddle/ap/include/index_expr/slice_method_class.h @@ -0,0 +1,45 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/index_expr/slice.h" + +namespace ap::index_expr { + +template +struct SliceMethodClass { + using This = SliceMethodClass; + using Self = Slice; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self->ToString(); + } +}; + +template +axpr::TypeImpl> GetSliceClass() { + using ImplMethods = SliceMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "Slice", + [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/valid_index_expr_builder.h b/paddle/ap/include/index_expr/valid_index_expr_builder.h new file mode 100644 index 00000000000000..a2847ec515d9aa --- /dev/null +++ b/paddle/ap/include/index_expr/valid_index_expr_builder.h @@ -0,0 +1,248 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/axpr/error.h"
+#include "paddle/ap/include/index_expr/index_expr.h"
+#include "paddle/ap/include/index_expr/index_expr_util.h"
+#include "paddle/ap/include/index_expr/index_tuple_expr.h"
+#include "paddle/ap/include/index_expr/slice.h"
+
+namespace ap::index_expr {
+
+using adt::Result;
+
// Validating factory for IndexExpr / IndexTupleExpr combinators: each method
// checks its arguments' domains/ranges/ranks before constructing the node,
// returning a typed error instead of an ill-formed expression.
// NOTE(review): all `<...>` template argument lists in this file (Result<T>,
// Has<T>(), Get<T>(), adt::List<T>, ...) were stripped when this patch was
// mangled; the code below is kept byte-identical to the mangled text rather
// than guessed at. Recover the originals from the upstream commit.
+class ValidIndexExprBuilder {
+ public:
+  ValidIndexExprBuilder() {}
+  ValidIndexExprBuilder(const ValidIndexExprBuilder&) = delete;
+  ValidIndexExprBuilder(ValidIndexExprBuilder&&) = delete;
+
+  Result BroadcastMask(const symbol::DimExpr& dim_expr,
+                       const IndexExpr& index_expr) {
+    return IndexExprBroadcastMask{dim_expr, index_expr};
+  }
+
+  Result Slice(const ap::index_expr::Slice& slice,
+               const symbol::DimExpr& range,
+               const IndexExpr& index_expr) {
+    return IndexExprSlice{slice, range, index_expr};
+  }
+
+  Result Affine(const ap::index_expr::Slice& slice,
+                const symbol::DimExpr& range,
+                const IndexExpr& index_expr) {
+    return IndexExprAffine{slice, range, index_expr};
+  }
+
  // Union of two index exprs; only valid when both share the same DimExpr
  // domain (checked structurally via std::visit over the domain variants).
+  Result DisjointUnion(const IndexExpr& lhs_index_expr,
+                       const IndexExpr& rhs_index_expr) {
+    const auto& lhs_domain = IndexExprGetDomain(lhs_index_expr);
+    const auto& rhs_domain = IndexExprGetDomain(rhs_index_expr);
+    const auto& pattern_match = ::common::Overloaded{
+        [](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) {
+          return lhs == rhs;
+        },
+        [](const auto&, const auto&) { return false; }};
+    const bool do_equal =
+        std::visit(pattern_match, lhs_domain.variant(), rhs_domain.variant());
+    if (!do_equal) {
+      return adt::errors::TypeError{
+          "domain of `lhs_index_expr' does not equal to domain of "
+          "`rhs_index_expr'"};
+    }
+    return index_expr::DisjointUnion{lhs_index_expr, rhs_index_expr};
+  }
+
  // Permutation combinator; requires `perms` to be a valid permutation whose
  // size equals the rank of `indexes_expr`.
+  Result Permute(const adt::List& perms,
+                 const IndexTupleExpr& indexes_expr) {
+    if (!IsValidPerm(perms)) {
+      return adt::errors::InvalidArgumentError{"argument `perms` is not valid"};
+    }
+    const auto& rank = IndexTupleExprGetRank(indexes_expr);
+    if (!rank.Has()) {
+      return adt::errors::InvalidArgumentError{
+          "wrong indexes_expr argument for IndexTupleExprPermute"};
+    }
+    if (rank.Get() != perms->size()) {
+      return adt::errors::InvalidArgumentError{std::string(
+          "the rank of perms does not equal to the rank of "
+          "indexes_expr. rank(perm): " +
+          std::to_string(perms->size()) +
+          ", rank(indexes_expr): " + std::to_string(rank.Get()))};
+    }
+    return IndexTupleExprPermute{perms, indexes_expr};
+  }
+
  // Reshape combinator; `shape` must be non-negative and have the same total
  // element count (symbolic product) as the ranges of `indexes_expr`.
  // NOTE(review): the Has()/Get() predicates below lost their template
  // arguments; the sense of the `opt_ranges.Has()` check (error-alternative
  // vs value-alternative) cannot be recovered from this text -- verify.
+  Result Reshape(const adt::List& shape,
+                 const IndexTupleExpr& indexes_expr) {
+    if (ContainsNegative(shape)) {
+      return adt::errors::InvalidArgumentError{
+          "dims in argument `shape` have negative integer"};
+    }
+    const auto& opt_ranges = IndexTupleExprGetRanges(indexes_expr);
+    if (opt_ranges.Has()) {
+      return adt::errors::InvalidArgumentError{
+          "argument `indexes_expr` is not a ranked IndexTupleExpr"};
+    }
+    if (!ProductEqual(shape, opt_ranges.Get>())) {
+      return adt::errors::InvalidArgumentError{
+          "product of argument `shape` does not equal to elements of "
+          "`indexes_expr`"};
+    }
+    return IndexTupleExprReshape{shape, indexes_expr};
+  }
+
  // Transform combinator: applies one IndexExpr per dimension; each expr's
  // domain must match the corresponding range of `indexes_expr`.
+  Result Transform(
+      const adt::List& transform_index_exprs,
+      const IndexTupleExpr& indexes_expr) {
+    const auto& opt_rank = IndexTupleExprGetRank(indexes_expr);
+    if (!opt_rank.Has()) {
+      return adt::errors::TypeError{
+          "The first argument of IndexTupleExprTransform must be a ranked "
+          "IndexTupleExpr."};
+    }
+    const auto& opt_ranges = IndexTupleExprGetRanges(indexes_expr);
+    if (!opt_ranges.Has>()) {
+      return adt::errors::RuntimeError{
+          "error occurred where calling IndexTupleExprGetDims"};
+    }
+    const auto& ranges = opt_ranges.Get>();
+    if (opt_rank.Get() != transform_index_exprs->size()) {
+      return adt::errors::TypeError{
+          "The rank of first argument must equal to number of lambdas."};
+    }
+    adt::List domains{};
+    domains->reserve(transform_index_exprs->size());
+    for (const auto& index_expr : *transform_index_exprs) {
+      const auto& domain = IndexExprGetDomain(index_expr);
+      if (!domain.Has()) {
+        return adt::errors::TypeError{
+            "one of transform_index_exprs has no demain."};
+      }
+      domains->emplace_back(domain.Get());
+    }
+    if (ranges != domains) {
+      return adt::errors::TypeError{
+          "domain of `transform_index_exprs' does not equal to range of "
+          "`indexes_expr'."};
+    }
+    return IndexTupleExprTransform{transform_index_exprs,
+                                   indexes_expr};
+  }
+
+  // outer(inner(x)) == (outer . inner)(x)
  // Function composition over IndexTupleExpr: recursively rebuilds `outer`
  // on top of `inner`, re-validating each combinator along the way. Undefined/
  // Nothing/IntArrayLike short-circuit to themselves; a bare Domain checks
  // that `inner`'s range matches and collapses to `inner`.
+  Result Compose(const IndexTupleExpr& outer,
+                 const IndexTupleExpr& inner) {
+    return outer.Match(
+        [&](const UndefinedIndexTupleExpr& impl) -> Result {
+          return impl;
+        },
+        [&](const NothingIndexTupleExpr& impl) -> Result {
+          return impl;
+        },
+        [&](const IntArrayLikeIndexTupleExpr& impl) -> Result {
+          return impl;
+        },
+        [&](const IndexTupleExprDomain& domain) -> Result {
+          const auto& ranges = IndexTupleExprGetRanges(inner);
+          if (ranges.Has()) {
+            return adt::errors::TypeError{"`inner_indexes_expr' has no range."};
+          }
+          if (ranges.Get>() != domain->ranges) {
+            return adt::errors::TypeError{
+                "the domain of `outer_indexes_expr' does not equal to the "
+                "range "
+                "of `inner_indexes_expr'."};
+          }
+          return inner;
+        },
+        [&](const IndexTupleExprPermute& perm)
+            -> Result {
+          const auto& composed_inner = Compose(perm->indexes_expr, inner);
+          if (composed_inner.HasError()) {
+            return composed_inner.GetError();
+          }
+          return Permute(perm->perms, composed_inner.Get());
+        },
+        [&](const IndexTupleExprReshape& reshape)
+            -> Result {
+          const auto& composed_inner = Compose(reshape->indexes_expr, inner);
+          if (composed_inner.HasError()) {
+            return composed_inner.GetError();
+          }
+          return Reshape(reshape->shape, composed_inner.Get());
+        },
+        [&](const IndexTupleExprTransform& transform)
+            -> Result {
+          const auto& composed_inner = Compose(transform->indexes_expr, inner);
+          if (composed_inner.HasError()) {
+            return composed_inner.GetError();
+          }
+          return Transform(transform->index_exprs,
+                           composed_inner.Get());
+        });
+  }
+
+ private:
  // True iff `perms` is a bijection of [0, perms->size()).
+  template
+  bool IsValidPerm(const PermsT& perms) {
+    std::vector idx2touched(perms->size(), false);
+    for (int64_t perm : *perms) {
+      if (perm < 0) {
+        return false;
+      }
+      if (perm >= perms->size()) {
+        return false;
+      }
+      idx2touched[perm] = true;
+    }
+    for (bool touched : idx2touched) {
+      if (!touched) {
+        return false;
+      }
+    }
+    return true;
+  }
+
  // True iff any *concrete* (int64-valued) dim is negative; symbolic dims
  // are skipped.
+  template
+  bool ContainsNegative(const ShapeT& shape) {
+    for (const auto& dim : *shape) {
+      if (!dim.template Has()) {
+        continue;
+      }
+      if (dim.template Get() < 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+
  // Symbolic product of all dims (1 for an empty list).
+  template
+  symbol::DimExpr Product(const DimExprsT& dim_exprs) {
+    symbol::DimExpr ret_expr{1};
+    for (const auto& dim : *dim_exprs) {
+      ret_expr = ret_expr * dim;
+    }
+    return ret_expr;
+  }
+
+  bool ProductEqual(const auto& lhs, const auto& rhs) {
+    return Product(lhs) == Product(rhs);
+  }
+};
+
+}  // namespace ap::index_expr
diff --git a/paddle/ap/include/index_expr/value.h b/paddle/ap/include/index_expr/value.h
new file mode 100644
index 00000000000000..dbf6061c5e50d2
--- /dev/null
+++ b/paddle/ap/include/index_expr/value.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/axpr/builtin_functions.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/dim_expr.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" +#include "paddle/pir/include/core/attribute.h" +#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h" + +namespace ap::index_expr { + +using axpr::Value; + +using Val = Value; + +} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/value_method_class.h b/paddle/ap/include/index_expr/value_method_class.h new file mode 100644 index 00000000000000..dcdf380fc36411 --- /dev/null +++ b/paddle/ap/include/index_expr/value_method_class.h @@ -0,0 +1,22 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/dim_expr_method_class.h" +#include "paddle/ap/include/axpr/value_method_class.h" +#include "paddle/ap/include/index_expr/index_expr_method_class.h" +#include "paddle/ap/include/index_expr/index_tuple_expr_method_class.h" +#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h" +#include "paddle/ap/include/index_expr/slice_method_class.h" diff --git a/paddle/ap/include/ir_match/graph_match_ctx.h b/paddle/ap/include/ir_match/graph_match_ctx.h new file mode 100644 index 00000000000000..499fb7951b9a99 --- /dev/null +++ b/paddle/ap/include/ir_match/graph_match_ctx.h @@ -0,0 +1,280 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/ap/include/drr/drr_graph_descriptor.h"
+#include "paddle/ap/include/drr/node.h"
+#include "paddle/ap/include/drr/topo_kind.h"
+#include "paddle/ap/include/drr/value.h"
+#include "paddle/ap/include/graph/graph_descriptor.h"
+#include "paddle/ap/include/graph/node.h"
+#include "paddle/ap/include/ir_match/topo_match_ctx.h"
+
+namespace ap::ir_match {
+
// Match context mapping "small graph" (drr pattern) nodes to "big graph"
// (program IR) nodes, layered on top of a TopoMatchCtx. Adds drr-aware
// traversal of native/packed ir-value nodes.
// NOTE(review): all template parameter lists in this file (including the
// struct's own `template <...>` header and the bg_node_t parameter the
// methods rely on) were stripped when this patch was mangled; the code is
// kept byte-identical to the mangled text. Recover the originals upstream.
+template
+struct GraphMatchCtxImpl {
+  using DrrNode = drr::Node;
+  using DrrNativeIrValue = drr::NativeIrValue;
+  using DrrPackedIrValue = drr::PackedIrValue;
+  using sg_node_t = graph::Node;
+
+  TopoMatchCtx topo_match_ctx;
+
  // Identity comparison only: two contexts are equal iff they are the same
  // object.
+  bool operator==(const GraphMatchCtxImpl& other) const {
+    return this == &other;
+  }
+
+  std::size_t num_matched_bg_nodes() const {
+    return topo_match_ctx->num_matched_bg_nodes();
+  }
+
+  adt::Result HasBigGraphNode(const sg_node_t& node) const {
+    return topo_match_ctx->HasBigGraphNode(node);
+  }
+
  // Counts big-graph ir-value nodes matched to `node` by visiting them.
+  adt::Result GetNumBigGraphIrValueNodes(
+      const sg_node_t& node) const {
+    std::size_t num = 0;
+    auto Increase = [&](const auto&) -> adt::Result {
+      ++num;
+      return adt::Ok{};
+    };
+    ADT_RETURN_IF_ERR(VisitBigGraphIrValueNode(node, Increase));
+    return num;
+  }
+
  // Dispatch on the drr node kind: a native ir value maps to exactly one
  // big-graph node; a packed ir value expands to a set of big-graph nodes.
+  template
+  adt::Result VisitBigGraphIrValueNode(const sg_node_t& node,
+                                       const YieldT& Yield) const {
+    ADT_LET_CONST_REF(drr_node, node.Get());
+    using Ok = adt::Result;
+    return drr_node.Match(
+        [&](const DrrNativeIrValue&) -> Ok {
+          ADT_LET_CONST_REF(bir_node, GetSoleBigGraphNode(node));
+          return Yield(bir_node);
+        },
+        [&](const DrrPackedIrValue&) -> Ok {
+          return VisitPackedBigGraphIrValueNode(node, Yield);
+        },
+        [&](const auto& impl) -> Ok {
+          using T = std::decay_t;
+          return adt::errors::NotImplementedError{
+              std::string() +
+              "VisitBigGraphIrValueNode() support DrrNativeIrValue and "
+              "DrrPackedIrValue only, " +
+              typeid(T).name() + " found."};
+        });
+  }
+
+  adt::Result GetSoleBigGraphNode(const sg_node_t& node) const {
+    return topo_match_ctx->GetSoleBigGraphNode(node);
+  }
+
+  std::optional GetOptMatchedSmallGraphNode(
+      const bg_node_t& bg_node) const {
+    return topo_match_ctx->GetMatchedSmallGraphNode(bg_node);
+  }
+
+  using DefaultDrrGraph =
+      graph::GraphDescriptor;
+
  // Materializes the packed expansion of `node` into a list.
+  adt::Result> GetPackedBigGraphIrValueNodes(
+      const sg_node_t& node) const {
+    adt::List ret;
+    using Ok = adt::Result;
+    auto CollectInput = [&](const bg_node_t& node) -> Ok {
+      ret->emplace_back(node);
+      return adt::Ok{};
+    };
+    ADT_RETURN_IF_ERR(VisitPackedBigGraphIrValueNode(node, CollectInput));
+    return ret;
+  }
+
  // A packed ir value is either a pure input (0 in / 1 out) or a pure output
  // (1 in / 0 out) of its packed op; anything else is rejected.
+  template
+  adt::Result VisitPackedBigGraphIrValueNode(
+      const sg_node_t& node, const YieldT& Yield) const {
+    ADT_LET_CONST_REF(drr_node, node.Get());
+    ADT_CHECK(drr_node.template Has());
+    DefaultDrrGraph drr_graph{};
+    ADT_LET_CONST_REF(is_ignored, drr_graph.IgnoredNode(node));
+    ADT_CHECK(is_ignored);
+    ADT_LET_CONST_REF(num_inputs, drr_graph.GetNumInputs(drr_node));
+    ADT_LET_CONST_REF(num_outputs, drr_graph.GetNumOutputs(drr_node));
+    if (num_inputs == 0 && num_outputs == 1) {
+      return VisitPackedInputBigGraphNode(node, Yield);
+    }
+    if (num_inputs == 1 && num_outputs == 0) {
+      return VisitPackedOutputBigGraphNode(node, Yield);
+    }
+    return adt::errors::TypeError{
+        std::string() +
+        "VisitPackedBigGraphIrValueNode() failed. num_inputs: " +
+        std::to_string(num_inputs) +
+        ", num_outputs: " + std::to_string(num_outputs)};
+  }
+
  // Walks drr edges value -> op-operand -> op, then yields the op's big-graph
  // inputs, excluding those already matched to native ir values.
+  template
+  adt::Result VisitPackedInputBigGraphNode(
+      const sg_node_t& packed_ir_value_node, const YieldT& Yield) const {
+    ADT_LET_CONST_REF(packed_ir_value_drr_node, packed_ir_value_node.Get());
+    DefaultDrrGraph drr_graph{};
+    ADT_LET_CONST_REF(drr_packed_ir_op_operand_node,
+                      drr_graph.GetSoleOutput(packed_ir_value_drr_node));
+    ADT_LET_CONST_REF(drr_packed_ir_op_node,
+                      drr_graph.GetSoleOutput(drr_packed_ir_op_operand_node));
+    ADT_LET_CONST_REF(
+        exclude_bir_native_ir_values,
+        GetBirNativeIrInputsOfPackedIrOp(drr_packed_ir_op_node.node()));
+    using Ok = adt::Result;
+    auto YieldIgnored = [&](const bg_node_t& node) -> Ok {
+      if (exclude_bir_native_ir_values.count(node) == 0) {
+        return Yield(node);
+      }
+      return adt::Ok{};
+    };
+    ADT_RETURN_IF_ERR(VisitBirIrInputOfPackedIrOp(drr_packed_ir_op_node.node(),
+                                                  YieldIgnored));
+    return adt::Ok{};
+  }
+
  // Mirror of VisitPackedInputBigGraphNode for the op's outputs.
+  template
+  adt::Result VisitPackedOutputBigGraphNode(
+      const sg_node_t& packed_ir_value_node, const YieldT& Yield) const {
+    ADT_LET_CONST_REF(packed_ir_value_drr_node, packed_ir_value_node.Get());
+    DefaultDrrGraph drr_graph{};
+    ADT_LET_CONST_REF(drr_packed_ir_op_result_node,
+                      drr_graph.GetSoleInput(packed_ir_value_drr_node));
+    ADT_LET_CONST_REF(drr_packed_ir_op_node,
+                      drr_graph.GetSoleInput(drr_packed_ir_op_result_node));
+    ADT_LET_CONST_REF(
+        exclude_bir_native_ir_values,
+        GetBirNativeIrOutputsOfPackedIrOp(drr_packed_ir_op_node.node()));
+    using Ok = adt::Result;
+    auto YieldIgnored = [&](const bg_node_t& node) -> Ok {
+      if (exclude_bir_native_ir_values.count(node) == 0) {
+        return Yield(node);
+      }
+      return adt::Ok{};
+    };
+    ADT_RETURN_IF_ERR(VisitBirIrOutputOfPackedIrOp(drr_packed_ir_op_node.node(),
+                                                   YieldIgnored));
+    return adt::Ok{};
+  }
+
+  using DefaultBirGraph =
+      graph::GraphDescriptor;
+
  // Yields the big-graph ir values feeding the bir op matched to the given
  // drr packed op node (two-hop upstream walk: op -> operand -> value).
+  template
+  adt::Result VisitBirIrInputOfPackedIrOp(
+      const sg_node_t& drr_packed_ir_op_node, const YieldT& Yield) const {
+    DefaultBirGraph bir_graph{};
+    ADT_LET_CONST_REF(bir_packed_or_ref_ir_op_node,
+                      GetSoleBigGraphNode(drr_packed_ir_op_node));
+    using Ok = adt::Result;
+    auto VisitIrOpOperand = [&](const bg_node_t& node) -> Ok {
+      return bir_graph.VisitUpstreamNodes(node, Yield);
+    };
+    ADT_RETURN_IF_ERR(bir_graph.VisitUpstreamNodes(bir_packed_or_ref_ir_op_node,
+                                                   VisitIrOpOperand));
+    return adt::Ok{};
+  }
+
  // Mirror of VisitBirIrInputOfPackedIrOp, walking downstream.
+  template
+  adt::Result VisitBirIrOutputOfPackedIrOp(
+      const sg_node_t& drr_packed_ir_op_node, const YieldT& Yield) const {
+    DefaultBirGraph bir_graph{};
+    ADT_LET_CONST_REF(bir_packed_or_ref_ir_op_node,
+                      GetSoleBigGraphNode(drr_packed_ir_op_node));
+    using Ok = adt::Result;
+    auto VisitIrOpResult = [&](const bg_node_t& node) -> Ok {
+      return bir_graph.VisitDownstreamNodes(node, Yield);
+    };
+    ADT_RETURN_IF_ERR(bir_graph.VisitDownstreamNodes(
+        bir_packed_or_ref_ir_op_node, VisitIrOpResult));
+    return adt::Ok{};
+  }
+
  // Collects the big-graph nodes matched to the op's *non-ignored* (native)
  // drr input values; at most one ignored (packed) input is allowed.
+  adt::Result> GetBirNativeIrInputsOfPackedIrOp(
+      const sg_node_t& packed_ir_op_node) const {
+    DefaultDrrGraph drr_graph{};
+    std::unordered_set set;
+    using Ok = adt::Result;
+    int num_ignored = 0;
+    auto VisitIrValue = [&](const sg_node_t& node) -> Ok {
+      ADT_LET_CONST_REF(ignored, drr_graph.IgnoredNode(node));
+      if (!ignored) {
+        ADT_LET_CONST_REF(bir_node, GetSoleBigGraphNode(node));
+        set.insert(bir_node);
+      } else {
+        ++num_ignored;
+      }
+      return adt::Ok{};
+    };
+    auto VisitIrOpOperand = [&](const sg_node_t& node) -> Ok {
+      return drr_graph.VisitUpstreamNodes(node, VisitIrValue);
+    };
+    ADT_RETURN_IF_ERR(
+        drr_graph.VisitUpstreamNodes(packed_ir_op_node, VisitIrOpOperand));
+    ADT_CHECK(num_ignored <= 1) << adt::errors::NotImplementedError{
+        std::string() +
+        "multiple packed ir value inputs are not supported yet."};
+    return set;
+  }
+
  // Mirror of GetBirNativeIrInputsOfPackedIrOp for the op's outputs.
+  adt::Result> GetBirNativeIrOutputsOfPackedIrOp(
+      const sg_node_t& packed_ir_op_node) const {
+    DefaultDrrGraph drr_graph{};
+    std::unordered_set set;
+    using Ok = adt::Result;
+    int num_ignored = 0;
+    auto VisitIrValue = [&](const sg_node_t& node) -> Ok {
+      ADT_LET_CONST_REF(ignored, drr_graph.IgnoredNode(node));
+      if (!ignored) {
+        ADT_LET_CONST_REF(bir_node, GetSoleBigGraphNode(node));
+        set.insert(bir_node);
+      } else {
+        ++num_ignored;
+      }
+      return adt::Ok{};
+    };
+    auto VisitIrOpResult = [&](const sg_node_t& node) -> Ok {
+      return drr_graph.VisitDownstreamNodes(node, VisitIrValue);
+    };
+    ADT_RETURN_IF_ERR(
+        drr_graph.VisitDownstreamNodes(packed_ir_op_node, VisitIrOpResult));
+    ADT_CHECK(num_ignored <= 1) << adt::errors::NotImplementedError{
+        std::string() +
+        "multiple packed ir value outputs are not supported yet."};
+    return set;
+  }
+
  // Strict variant of GetOptMatchedSmallGraphNode: errors if unmatched.
+  adt::Result GetMatchedSmallGraphNode(
+      const bg_node_t& bg_node) const {
+    const auto& sg_node = topo_match_ctx->GetMatchedSmallGraphNode(bg_node);
+    ADT_CHECK(sg_node.has_value());
+    return sg_node.value();
+  }
+
+  template
+  adt::Result VisitSmallGraphNode(const YieldT& Yield) const {
+    return topo_match_ctx->VisitSmallGraphNode(Yield);
+  }
+};
+
+template
+ADT_DEFINE_RC(GraphMatchCtx, GraphMatchCtxImpl);
+
+}  // namespace ap::ir_match
diff --git a/paddle/ap/include/ir_match/graph_matcher.h b/paddle/ap/include/ir_match/graph_matcher.h
new file mode 100644
index 00000000000000..1f9a8af334d279
--- /dev/null
+++ b/paddle/ap/include/ir_match/graph_matcher.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/topo_kind.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/graph/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/ir_match/graph_match_ctx.h" +#include "paddle/ap/include/ir_match/topo_matcher.h" + +namespace ap::ir_match { + +template +struct GraphMatcher { + using DrrNode = drr::Node; + using DrrNativeIrOp = drr::NativeIrOp; + using sg_node_t = graph::Node; + + TopoMatcher topo_matcher_; + + GraphMatcher(const GraphDescriptor& bg_descriptor, + const GraphDescriptor& sg_descriptor) + : topo_matcher_(bg_descriptor, sg_descriptor) {} + + GraphMatcher(const GraphMatcher&) = delete; + GraphMatcher(GraphMatcher&&) = delete; + + adt::Result> MatchByAnchor( + const bg_node_t& bg_node, const sg_node_t& anchor_node) { + ADT_LET_CONST_REF(topo_match_ctx, + topo_matcher_.MatchByAnchor(bg_node, anchor_node)); + return GraphMatchCtx{topo_match_ctx}; + } + + template + adt::Result VisitMisMatchedNodes( + const GraphMatchCtx& graph_match_ctx, + const sg_node_t& anchor_node, + const DoEachT& DoEach) const { + const auto& topo_match_ctx = graph_match_ctx->topo_match_ctx; + return topo_matcher_.VisitMisMatchedNodes( + topo_match_ctx, anchor_node, DoEach); + } + + adt::Result UpdateByConnectionsUntilDone( + GraphMatchCtx* ctx, const sg_node_t& anchor_node) { + ADT_LET_CONST_REF(new_topo_match_ctx, + Solve((*ctx)->topo_match_ctx, anchor_node)); + (*ctx)->topo_match_ctx = new_topo_match_ctx; + return adt::Ok{}; + } + + adt::Result IsGraphMatched(const GraphMatchCtx& ctx, + const sg_node_t& anchor_node) const { + return topo_matcher_.IsGraphMatched(ctx->topo_match_ctx, anchor_node); + } + + adt::Result HasUndetermined(const GraphMatchCtx& ctx) const { + return topo_matcher_.HasUndetermined(ctx); + } + + template + adt::Result InplaceForcePickOneLastUndetermined( + GraphMatchCtx* ctx, const ReMatchT& ReMatch) const { + return 
InplaceForcePickOneLastUndetermined( + ctx, ReMatch, /*loop_limit=*/9999); + } + + template + adt::Result InplaceForcePickOneLastUndetermined( + GraphMatchCtx* ctx, + const ReMatchT& ReMatch, + int loop_limit) const { + return topo_matcher_.InplaceForcePickOneLastUndetermined( + ctx, ReMatch, loop_limit); + } + + private: + using TopoMatchCtxT = TopoMatchCtx; + + adt::Result Solve(TopoMatchCtxT topo_match_ctx, + const sg_node_t& anchor_node) { + ADT_RETURN_IF_ERR(topo_matcher_.UpdateByConnectionsUntilDone( + &*topo_match_ctx, anchor_node)); + const auto& opt_iter = topo_match_ctx->GetFirstUnsolved(); + if (!opt_iter.has_value()) { + return topo_match_ctx; + } + const auto& unsolved_sg_node = opt_iter.value()->first; + for (const auto& proprosal_bg_node : opt_iter.value()->second) { + ADT_LET_CONST_REF(impl, + topo_match_ctx->CloneAndSetUnsolved(unsolved_sg_node, + proprosal_bg_node)); + TopoMatchCtxT proprosal_topo_match_ctx{impl}; + ADT_LET_CONST_REF(solved, Solve(proprosal_topo_match_ctx, anchor_node)); + if (!solved->GetFirstMismatched().has_value()) { + return solved; + } + } + // all proposals failed. + return topo_match_ctx; + } +}; + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/ir_match_ctx.h b/paddle/ap/include/ir_match/ir_match_ctx.h new file mode 100644 index 00000000000000..4716e875aae443 --- /dev/null +++ b/paddle/ap/include/ir_match/ir_match_ctx.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/ir_match/graph_match_ctx.h" +#include "paddle/ap/include/ir_match/op_match_ctx.h" +#include "paddle/ap/include/ir_match/tensor_match_ctx.h" + +namespace ap::ir_match { + +template +struct IrMatchCtxImpl { + using DrrNodeT = drr::Node; + using SmallGraphNodeT = graph::Node; + drr::SourcePatternCtx source_pattern_ctx; + GraphMatchCtx graph_match_ctx; +}; + +template +ADT_DEFINE_RC(IrMatchCtx, IrMatchCtxImpl); + +} // namespace ap::ir_match + +namespace ap::axpr { + +template +struct TypeImpl> : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "IrMatchCtx"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/ir_match/native_or_ref_ir_value.h b/paddle/ap/include/ir_match/native_or_ref_ir_value.h new file mode 100644 index 00000000000000..be955e3b5e6eb9 --- /dev/null +++ b/paddle/ap/include/ir_match/native_or_ref_ir_value.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::ir_match { + +template +using NativeOrRefIrValueImpl = std::variant; + +template +struct NativeOrRefIrValue : public NativeOrRefIrValueImpl { + using NativeOrRefIrValueImpl::NativeOrRefIrValueImpl; + ADT_DEFINE_VARIANT_METHODS(NativeOrRefIrValueImpl); + + template + static adt::Result CastFrom(const ValueT& val) { + using RetT = adt::Result; + return val.Match( + [](const typename BirNode::native_value_type& impl) -> RetT { + return impl; + }, + [](const typename BirNode::ref_value_type& impl) -> RetT { + return impl; + }, + [](const auto& impl) -> RetT { + using T = std::decay_t; + const char* type_name = typeid(T).name(); + return adt::errors::ValueError{ + std::string() + + "NativeOrRefIrValue::CastFrom failed. type(val): " + type_name}; + }); + } +}; + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/op_match_ctx.h b/paddle/ap/include/ir_match/op_match_ctx.h new file mode 100644 index 00000000000000..602d4d2e0f850d --- /dev/null +++ b/paddle/ap/include/ir_match/op_match_ctx.h @@ -0,0 +1,39 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::ir_match { + +template +struct IrMatchCtxImpl; + +template +struct OpMatchCtxImpl { + std::weak_ptr> ir_mtach_ctx; + + bool operator==(const OpMatchCtxImpl& other) const { return this == &other; } +}; + +template +ADT_DEFINE_RC(OpMatchCtx, OpMatchCtxImpl); + +template +axpr::TypeImpl> GetOpMatchCtxClass(); + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/op_match_ctx_method_class.h b/paddle/ap/include/ir_match/op_match_ctx_method_class.h new file mode 100644 index 00000000000000..18e3b8a544affd --- /dev/null +++ b/paddle/ap/include/ir_match/op_match_ctx_method_class.h @@ -0,0 +1,120 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" +#include "paddle/ap/include/ir_match/op_match_ctx.h" + +namespace ap::ir_match { + +template +struct OpMatchCtxMethodClass { + using This = OpMatchCtxMethodClass; + using Self = ir_match::OpMatchCtx; + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, axpr::TryGetImpl(attr_name_val)); + ADT_LET_CONST_REF(ir_op, This{}.GetIrOpByName(self, attr_name)); + if (ir_op.has_value()) { + return ir_op.value(); + } + return adt::errors::TypeError{ + std::string() + "'OpMatchCtx' has no attribute '" + attr_name + "'"}; + } + + using DrrNativeIrOp = drr::NativeIrOp; + using DrrPackedIrOp = drr::PackedIrOp; + using DrrOptPackedIrOp = drr::OptPackedIrOp; + using SmallGraphNodeT = graph::Node; + + using IrNativeIrOp = typename BirNode::native_op_type; + using IrPackedIrOp = typename BirNode::packed_op_type; + using IrRefIrOp = typename BirNode::ref_op_type; + + adt::Result> GetIrOpByName( + const Self& self, const std::string& attr_name) { + ADT_LET_CONST_REF(ir_match_ctx, adt::WeakPtrLock(self->ir_mtach_ctx)); + const auto& source_pattern_ctx = ir_match_ctx->source_pattern_ctx; + const auto& op_pattern_ctx = source_pattern_ctx->op_pattern_ctx; + const auto& iter = op_pattern_ctx->uid2ir_op.find(attr_name); + if (iter == op_pattern_ctx->uid2ir_op.end()) { + return std::nullopt; + } + auto GetIrOpBySmallGraphNode = + [&](const SmallGraphNodeT& node) -> adt::Result { + const auto& graph_match_ctx = ir_match_ctx->graph_match_ctx; + return graph_match_ctx->GetSoleBigGraphNode(node); + }; + ADT_LET_CONST_REF( + ir_node, + iter->second.Match( + [&](const DrrNativeIrOp& native_ir_op) 
-> adt::Result { + return GetIrOpBySmallGraphNode(native_ir_op->node); + }, + [&](const DrrPackedIrOp& packed_ir_op) -> adt::Result { + return GetIrOpBySmallGraphNode(packed_ir_op->node); + }, + [&](const DrrOptPackedIrOp& packed_ir_op) -> adt::Result { + return GetIrOpBySmallGraphNode(packed_ir_op->node); + }, + [&](const auto&) -> adt::Result { + return adt::errors::ValueError{ + std::string() + "Failed to get OpMatchCtx attribute, '" + + attr_name + "' is a unbounded op which should not be."}; + })); + ADT_LET_CONST_REF( + ir_op, + ir_node.Match( + [&](const IrNativeIrOp& impl) -> adt::Result { + axpr::BuiltinClassInstance instance{ + impl.template GetBuiltinClass(), impl}; + return ValueT{instance}; + }, + [&](const IrPackedIrOp& impl) -> adt::Result { + axpr::BuiltinClassInstance instance{ + impl.template GetBuiltinClass(), impl}; + return ValueT{instance}; + }, + [&](const IrRefIrOp& impl) -> adt::Result { + axpr::BuiltinClassInstance instance{ + impl.template GetBuiltinClass(), impl}; + return ValueT{instance}; + }, + [&](const auto&) -> adt::Result { + return adt::errors::RuntimeError{ + std::string() + + "a ptn op node has wrongly matched to a non-op ir node."}; + })); + return ir_op; + } +}; + +template +axpr::TypeImpl> GetOpMatchCtxClass() { + using Impl = OpMatchCtxMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "OpMatchCtx", + [&](const auto& Define) { Define("__getattr__", &Impl::GetAttr); })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/ref_match_ctx.h b/paddle/ap/include/ir_match/ref_match_ctx.h new file mode 100644 index 00000000000000..b0f3400125a799 --- /dev/null +++ b/paddle/ap/include/ir_match/ref_match_ctx.h @@ -0,0 +1,46 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/ir_match/ref_node_info.h" + +namespace ap::ir_match { + +template +struct RefMatchCtxImpl { + std::unordered_map>> + value2ref_node_info; + std::unordered_map> + operand2node_info; + + adt::Result AddRefNodeInfo( + const RefNodeInfo& node_info) { + auto* vec = &value2ref_node_info[node_info->ir_value]; + vec->emplace_back(node_info); + for (const auto& op_operand : *node_info->op_operands_subset) { + ADT_CHECK(operand2node_info.emplace(op_operand, node_info).second); + } + return adt::Ok{}; + } +}; + +template +ADT_DEFINE_RC(RefMatchCtx, RefMatchCtxImpl); + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/ref_node_info.h b/paddle/ap/include/ir_match/ref_node_info.h new file mode 100644 index 00000000000000..22928a2d1df5e1 --- /dev/null +++ b/paddle/ap/include/ir_match/ref_node_info.h @@ -0,0 +1,47 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::ir_match { + +template +struct RefNodeInfoImpl { + IrValueT ir_value; + adt::List op_operands_subset; + + bool operator==(const RefNodeInfoImpl& other) const { return this == &other; } +}; + +template +ADT_DEFINE_RC(RefNodeInfo, RefNodeInfoImpl); + +} // namespace ap::ir_match + +namespace std { + +template +struct hash> { + std::size_t operator()( + const ap::ir_match::RefNodeInfo& node) const { + return reinterpret_cast(node.shared_ptr().get()); + } +}; + +} // namespace std diff --git a/paddle/ap/include/ir_match/tags.h b/paddle/ap/include/ir_match/tags.h new file mode 100644 index 00000000000000..36b3c591b34d27 --- /dev/null +++ b/paddle/ap/include/ir_match/tags.h @@ -0,0 +1,21 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace ap::ir_match { + +ADT_DEFINE_TAG(tIsUpstream); + +} diff --git a/paddle/ap/include/ir_match/tensor_match_ctx.h b/paddle/ap/include/ir_match/tensor_match_ctx.h new file mode 100644 index 00000000000000..441be1ccd05383 --- /dev/null +++ b/paddle/ap/include/ir_match/tensor_match_ctx.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/type.h" + +namespace ap::ir_match { + +template +struct IrMatchCtxImpl; + +template +struct TensorMatchCtxImpl { + std::weak_ptr> ir_mtach_ctx; + + bool operator==(const TensorMatchCtxImpl& other) const { + return this == &other; + } +}; + +template +ADT_DEFINE_RC(TensorMatchCtx, TensorMatchCtxImpl); + +template +axpr::TypeImpl> GetTensorMatchCtxClass(); + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/tensor_match_ctx_method_class.h b/paddle/ap/include/ir_match/tensor_match_ctx_method_class.h new file mode 100644 index 00000000000000..5c4c263055b806 --- /dev/null +++ b/paddle/ap/include/ir_match/tensor_match_ctx_method_class.h @@ -0,0 +1,131 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" +#include "paddle/ap/include/ir_match/tensor_match_ctx.h" + +namespace ap::ir_match { + +template +struct TensorMatchCtxMethodClass { + using This = TensorMatchCtxMethodClass; + using Self = ir_match::TensorMatchCtx; + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, axpr::TryGetImpl(attr_name_val)); + ADT_LET_CONST_REF(ir_tensor, This{}.GetIrTensorByName(self, attr_name)); + if (ir_tensor.has_value()) { + return ir_tensor.value(); + } + return adt::errors::TypeError{std::string() + + "'TensorMatchCtx' has no attribute '" + + attr_name + "'"}; + } + + using DrrValueT = drr::Value; + using DrrNodeT = drr::Node; + using DrrNativeIrValue = drr::NativeIrValue; + using DrrPackedIrValue = drr::PackedIrValue; + using SmallGraphNodeT = graph::Node; + + using IrNativeIrValue = typename BirNode::native_value_type; + using IrPackedIrValue = typename BirNode::packed_value_type; + using IrRefIrValue = typename BirNode::ref_value_type; + + adt::Result> GetIrTensorByName( + const Self& self, const std::string& attr_name) { + ADT_LET_CONST_REF(ir_match_ctx, adt::WeakPtrLock(self->ir_mtach_ctx)); + const auto& source_pattern_ctx = ir_match_ctx->source_pattern_ctx; + const auto& tensor_pattern_ctx = source_pattern_ctx->tensor_pattern_ctx; + const auto& iter = tensor_pattern_ctx->uid2ir_value.find(attr_name); + if (iter == tensor_pattern_ctx->uid2ir_value.end()) { + return std::nullopt; + } + using RetT = adt::Result; + auto GetNativeIrValueBySmallGraphNode = + [&](const SmallGraphNodeT& node) -> RetT { + const auto& 
graph_match_ctx = ir_match_ctx->graph_match_ctx; + ADT_LET_CONST_REF(bir_value_node, + graph_match_ctx->GetSoleBigGraphNode(node)); + return CastFromBirValue(bir_value_node); + }; + auto GetPackedIrValuesBySmallGraphNode = + [&](const SmallGraphNodeT& node) -> RetT { + const auto& graph_match_ctx = ir_match_ctx->graph_match_ctx; + ADT_LET_CONST_REF(bir_nodes, + graph_match_ctx->GetPackedBigGraphIrValueNodes(node)); + adt::List ret; + ret->reserve(bir_nodes->size()); + for (const auto& bir_node : *bir_nodes) { + ADT_LET_CONST_REF(elt, CastFromBirValue(bir_node)); + ret->emplace_back(elt); + } + return ret; + }; + ADT_LET_CONST_REF( + ir_value, + iter->second.Match( + [&](const DrrNativeIrValue& native_ir_value) -> RetT { + return GetNativeIrValueBySmallGraphNode(native_ir_value->node); + }, + [&](const DrrPackedIrValue& packed_ir_value) -> RetT { + return GetPackedIrValuesBySmallGraphNode(packed_ir_value->node); + }, + [&](const auto&) -> RetT { + return adt::errors::ValueError{ + std::string() + "Failed to get OpMatchCtx attribute, '" + + attr_name + "' is a unbounded op which should not be."}; + })); + return ir_value; + } + + adt::Result CastFromBirValue(const BirNode& bir_value_node) { + return bir_value_node.Match( + [&](const IrNativeIrValue& impl) -> adt::Result { + axpr::BuiltinClassInstance instance{ + impl.template GetBuiltinClass(), impl}; + return ValueT{instance}; + }, + [&](const IrRefIrValue& impl) -> adt::Result { + axpr::BuiltinClassInstance instance{ + impl.template GetBuiltinClass(), impl}; + return ValueT{instance}; + }, + [&](const auto&) -> adt::Result { + return adt::errors::RuntimeError{ + std::string() + + "a drr op node has wrongly matched to a non-op ir node."}; + }); + } +}; + +template +axpr::TypeImpl> GetTensorMatchCtxClass() { + using ImplMethods = TensorMatchCtxMethodClass; + static auto cls( + axpr::MakeBuiltinClass("TensorMatchCtx", [&](const auto& Define) { + Define("__getattr__", &ImplMethods::GetAttr); + })); + return 
axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/topo_match_ctx.h b/paddle/ap/include/ir_match/topo_match_ctx.h new file mode 100644 index 00000000000000..8a1c617e7f2ae3 --- /dev/null +++ b/paddle/ap/include/ir_match/topo_match_ctx.h @@ -0,0 +1,236 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/ap/include/graph/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_descriptor.h" + +namespace ap::ir_match { + +template +struct TopoMatchCtxImpl { + TopoMatchCtxImpl() {} + TopoMatchCtxImpl(const TopoMatchCtxImpl&) = default; + TopoMatchCtxImpl(TopoMatchCtxImpl&&) = default; + + bool operator==(const TopoMatchCtxImpl& other) const { + return this == &other; + } + + std::size_t num_matched_bg_nodes() const { + return matched_bg_node2sg_node_.size(); + } + + bool HasBigGraphNode(const sg_node_t& node) const { + return this->sg_node2bg_nodes_.count(node) > 0; + } + + adt::Result GetSoleBigGraphNode(const sg_node_t& node) const { + ADT_LET_CONST_REF(bg_nodes, GetBigGraphNodes(node)); + ADT_CHECK(bg_nodes->size(), 1); + return *bg_nodes->begin(); + } + + std::optional GetMatchedSmallGraphNode( + const bg_node_t& bg_node) const { + const auto& iter = matched_bg_node2sg_node_.find(bg_node); + if (iter == matched_bg_node2sg_node_.end()) { + return 
std::nullopt; + } + return iter->second; + } + + bool HasMatchedSmallGraphNode(const bg_node_t& bg_node) const { + const auto& iter = matched_bg_node2sg_node_.find(bg_node); + return iter != matched_bg_node2sg_node_.end(); + } + + adt::Result*> GetBigGraphNodes( + const sg_node_t& node) const { + const auto& iter = this->sg_node2bg_nodes_.find(node); + if (iter == this->sg_node2bg_nodes_.end()) { + return adt::errors::KeyError{ + std::string() + "no node_id " + + graph::NodeDescriptor{}.DebugId(node) + " found."}; + } + return &iter->second; + } + + adt::Result*> MutBigGraphNodes( + const sg_node_t& node) const { + const auto& iter = this->sg_node2bg_nodes_.find(node); + if (iter == this->sg_node2bg_nodes_.end()) { + return adt::errors::KeyError{ + std::string() + "no node_id " + + graph::NodeDescriptor{}.DebugId(node) + " found."}; + } + return const_cast*>(&iter->second); + } + + adt::Result InitBigGraphNodes(const sg_node_t& sg_node, + const std::list& val) { + VLOG(0) << "InitBigGraphNodes. sg_node: " + << graph::NodeDescriptor{}.DebugId(sg_node) + << ", val:" << + [&] { + std::ostringstream ss; + for (const auto& val_node : val) { + ss << graph::NodeDescriptor{}.DebugId(val_node) << " "; + } + return ss.str(); + }(); + auto* ptr = &this->sg_node2bg_nodes_[sg_node]; + ADT_CHECK(ptr->empty()) << adt::errors::KeyError{ + "InitBigGraphNodes failed. 'sg_node' has been matched to existed " + "bg_nodes"}; + ADT_CHECK(!val.empty()) << adt::errors::MismatchError{ + "TopoMatchCtxImpl::InitBigGraphNodes: sg_node should not be matched to " + "empty."}; + for (const auto& bg_node : val) { + ADT_CHECK(!HasMatchedSmallGraphNode(bg_node)) << adt::errors::KeyError{ + "TopoMatchCtxImpl::InitBigGraphNodes failed. 
there is matched " + "bg_node in 'val'"}; + } + *ptr = val; + if (ptr->size() == 1) { + ADT_CHECK(matched_bg_node2sg_node_.emplace(*val.begin(), sg_node).second); + } + return adt::Ok{}; + } + + adt::Result UpdateBigGraphNodes( + const sg_node_t& sg_node, const std::unordered_set& val) { + ADT_CHECK(!val.empty()); + for (const auto& bg_node : val) { + const auto& opt_matched = GetMatchedSmallGraphNode(bg_node); + ADT_CHECK(!opt_matched.has_value() || opt_matched.value() == sg_node) + << adt::errors::KeyError{ + "UpdateBigGraphNodes failed. there is matched bg_node in " + "'val'"}; + } + auto* ptr = &this->sg_node2bg_nodes_[sg_node]; + VLOG(0) << "UpdateBigGraphNodes: sg_node: " + << graph::NodeDescriptor{}.DebugId(sg_node) + << ", old_val:" << + [&] { + std::ostringstream ss; + for (const auto& val_node : *ptr) { + ss << graph::NodeDescriptor{}.DebugId(val_node) << " "; + } + return ss.str(); + }(); + VLOG(0) << "UpdateBigGraphNodes: sg_node: " + << graph::NodeDescriptor{}.DebugId(sg_node) + << ", arg_val:" << + [&] { + std::ostringstream ss; + for (const auto& val_node : val) { + ss << graph::NodeDescriptor{}.DebugId(val_node) << " "; + } + return ss.str(); + }(); + for (auto lhs_iter = ptr->begin(); lhs_iter != ptr->end();) { + if (val.count(*lhs_iter) > 0) { + ++lhs_iter; + } else { + lhs_iter = ptr->erase(lhs_iter); + } + } + VLOG(0) << "UpdateBigGraphNodes: sg_node: " + << graph::NodeDescriptor{}.DebugId(sg_node) + << ", new_val: " << + [&] { + std::ostringstream ss; + for (const auto& val_node : *ptr) { + ss << graph::NodeDescriptor{}.DebugId(val_node) << " "; + } + return ss.str(); + }(); + if (ptr->size() == 1) { + const auto& iter = + matched_bg_node2sg_node_.emplace(*ptr->begin(), sg_node).first; + ADT_CHECK(iter->second == sg_node); + } + return adt::Ok{}; + } + + adt::Result> CloneAndSetUnsolved( + const sg_node_t& sg_node, const bg_node_t& bg_node) const { + auto ret = std::make_shared(*this); + ret->matched_bg_node2sg_node_[bg_node] = sg_node; + const 
auto& iter = ret->sg_node2bg_nodes_.find(sg_node); + ADT_CHECK(iter != ret->sg_node2bg_nodes_.end()); + ADT_CHECK(iter->second.size() > 1); + ret->sg_node2bg_nodes_[sg_node] = std::list{bg_node}; + return ret; + } + + using SgNode2BgNodes = std::unordered_map>; + + std::optional GetFirstUnsolved() + const { + for (auto iter = sg_node2bg_nodes_.begin(); iter != sg_node2bg_nodes_.end(); + ++iter) { + if (iter->second.size() > 1) { + return iter; + } + } + return std::nullopt; + } + + std::optional GetFirstMismatched() + const { + for (auto iter = sg_node2bg_nodes_.begin(); iter != sg_node2bg_nodes_.end(); + ++iter) { + if (iter->second.empty()) { + return iter; + } + } + return std::nullopt; + } + + template + adt::Result VisitSmallGraphNode(const YieldT& Yield) const { + for (const auto& [sg_node, _] : sg_node2bg_nodes_) { + ADT_RETURN_IF_ERR(Yield(sg_node)); + } + return adt::Ok{}; + } + + template + adt::Result LoopMutBigGraphNode(const YieldT& Yield) { + for (auto& [_, bg_nodes] : sg_node2bg_nodes_) { + ADT_LET_CONST_REF(ctrl, Yield(&bg_nodes)); + if (ctrl.template Has()) { + break; + } + } + return adt::Ok{}; + } + + private: + SgNode2BgNodes sg_node2bg_nodes_; + std::unordered_map matched_bg_node2sg_node_; +}; + +template +ADT_DEFINE_RC(TopoMatchCtx, TopoMatchCtxImpl); + +} // namespace ap::ir_match diff --git a/paddle/ap/include/ir_match/topo_matcher.h b/paddle/ap/include/ir_match/topo_matcher.h new file mode 100644 index 00000000000000..1c82fe17f1e38d --- /dev/null +++ b/paddle/ap/include/ir_match/topo_matcher.h @@ -0,0 +1,398 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/ap/include/drr/topo_kind.h" +#include "paddle/ap/include/graph/adt.h" +#include "paddle/ap/include/graph/graph_descriptor.h" +#include "paddle/ap/include/graph/graph_helper.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/graph/node_arena.h" +#include "paddle/ap/include/ir_match/tags.h" +#include "paddle/ap/include/ir_match/topo_match_ctx.h" + +namespace ap::ir_match { + +using graph::GraphDescriptor; +using graph::GraphHelper; + +template +struct TopoMatcher { + TopoMatcher(const GraphDescriptor& bg_descriptor, + const GraphDescriptor& sg_descriptor) + : bg_descriptor_(bg_descriptor), sg_descriptor_(sg_descriptor) {} + + TopoMatcher(const TopoMatcher&) = delete; + TopoMatcher(TopoMatcher&&) = delete; + + adt::Result> MatchByAnchor( + const bg_node_t& bg_node, const sg_node_t& anchor_node) { + ADT_LET_CONST_REF(topo_match_ctx, + MakeTopoMatchCtxFromAnchor(bg_node, anchor_node)); + ADT_RETURN_IF_ERR(UpdateByConnectionsUntilDone( + &*topo_match_ctx.shared_ptr(), anchor_node)); + return topo_match_ctx; + } + + adt::Result UpdateByConnectionsUntilDone( + TopoMatchCtxImpl* ctx, + const sg_node_t& anchor_node) { + size_t kDeadloopDetectionSize = 999999; + while (true) { + ADT_LET_CONST_REF(updated, UpdateAllByConnections(ctx, anchor_node)); + if (!updated) { + break; + } + if (--kDeadloopDetectionSize <= 0) { + return adt::errors::RuntimeError{"Dead loop detected."}; + } + } + return adt::Ok{}; + } + + template + adt::Result VisitMisMatchedNodes( + const 
TopoMatchCtx& ctx, + const sg_node_t& anchor_node, + const DoEachT& DoEach) const { + auto DoEachSGNode = [&](const sg_node_t& sg_node) -> adt::Result { + if (!ctx->HasBigGraphNode(sg_node)) { + return DoEach(sg_node); + } + return adt::Ok{}; + }; + adt::BfsWalker bfs_walker = + GraphHelper(sg_descriptor_).GetBfsWalker(); + return bfs_walker(anchor_node, DoEachSGNode); + } + + adt::Result IsGraphMatched( + const TopoMatchCtx& ctx, + const sg_node_t& anchor_node) const { + adt::BfsWalker bfs_walker = + GraphHelper(sg_descriptor_).GetBfsWalker(); + std::size_t num_sg_nodes = 0; + auto AccNumSgNodes = [&](const sg_node_t& sg_node) -> adt::Result { + ADT_CHECK(ctx->HasBigGraphNode(sg_node)) + << adt::errors::MismatchError{"IsGraphMatched: sg_node not matched."}; + ADT_LET_CONST_REF(bg_nodes, ctx->GetBigGraphNodes(sg_node)); + ADT_CHECK(bg_nodes->size() == 1) << adt::errors::MismatchError{ + "IsGraphMatched: more than 1 bg_nodes matched to one sg_node."}; + ++num_sg_nodes; + return adt::Ok{}; + }; + const auto& ret = bfs_walker(anchor_node, AccNumSgNodes); + if (ret.HasError()) { + ADT_CHECK(ret.GetError().template Has()) + << ret.GetError(); + return false; + } + return num_sg_nodes == ctx->num_matched_bg_nodes(); + } + + adt::Result NumUndeterminedNodes( + const GraphMatchCtx& ctx) const { + std::size_t num_undetermined_nodes = 0; + using LoopCtrl = adt::Result; + auto Acc = [&](auto* lst) -> LoopCtrl { + num_undetermined_nodes += (lst->size() > 1); + return adt::Continue{}; + }; + ADT_RETURN_IF_ERR( + ctx->topo_match_ctx.shared_ptr()->LoopMutBigGraphNode(Acc)); + return num_undetermined_nodes; + } + + adt::Result*>> + MutFirstUndeterminedBigGraphNodes(GraphMatchCtx* ctx) const { + std::optional*> ret; + using LoopCtrl = adt::Result; + auto Find = [&](auto* lst) -> LoopCtrl { + if (lst->size() > 1) { + ret = lst; + return adt::Break{}; + } + return adt::Continue{}; + }; + ADT_RETURN_IF_ERR( + (*ctx)->topo_match_ctx.shared_ptr()->LoopMutBigGraphNode(Find)); + return 
ret; + } + + template + adt::Result InplaceForcePickOneLastUndetermined( + GraphMatchCtx* ctx, + const RematchT& Rematch, + int loop_limit) const { + while (true) { + if (--loop_limit < 0) { + return adt::errors::TypeError{ + "dead loop detected in InplaceForcePickOneLastUndetermined()"}; + } + ADT_LET_CONST_REF(num_undetermined_nodes, NumUndeterminedNodes(*ctx)); + if (num_undetermined_nodes == 0) { + return adt::Ok{}; + } + if (num_undetermined_nodes == 1) { + break; + } + ADT_LET_CONST_REF(opt_lst, MutFirstUndeterminedBigGraphNodes(ctx)); + ADT_CHECK(opt_lst.has_value()); + ADT_CHECK(opt_lst.value()->size() > 1); + opt_lst.value()->resize(1); + ADT_LET_CONST_REF(ctrl, Rematch(ctx)); + if (ctrl.template Has()) { + return adt::Ok{}; + } + } + ADT_LET_CONST_REF(opt_lst, MutFirstUndeterminedBigGraphNodes(ctx)); + ADT_CHECK(opt_lst.has_value()); + ADT_CHECK(opt_lst.value()->size() > 1); + std::list candidate_lst(*opt_lst.value()); + opt_lst.value()->resize(1); + for (const auto& node : candidate_lst) { + *opt_lst.value()->begin() = node; + ADT_LET_CONST_REF(ctrl, Rematch(ctx)); + if (ctrl.template Has()) { + return adt::Ok{}; + } + } + return adt::Ok{}; + } + + private: + adt::Result> MakeTopoMatchCtxFromAnchor( + const bg_node_t& bg_node, const sg_node_t& anchor_node) { + TopoMatchCtx match_ctx{}; + const auto& ptn_bfs_walker = + GraphHelper(sg_descriptor_).GetBfsWalker(); + auto InitMatchCtx = [&](const sg_node_t& sg_node) -> adt::Result { + if (sg_node == anchor_node) { + std::list bg_nodes{bg_node}; + ADT_RETURN_IF_ERR(match_ctx->InitBigGraphNodes(anchor_node, bg_nodes)); + } else { + ADT_RETURN_IF_ERR(TopoMatchCtxInitNode(&*match_ctx, sg_node)); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(ptn_bfs_walker(anchor_node, InitMatchCtx)); + return match_ctx; + } + + adt::Result UpdateAllByConnections( + TopoMatchCtxImpl* match_ctx, + const sg_node_t& anchor_node) { + const auto& ptn_bfs_walker = + GraphHelper(sg_descriptor_).GetBfsWalker(); + bool updated = false; 
+ auto Update = [&](const sg_node_t& sg_node) -> adt::Result { + // no need to update anchor_node. + if (anchor_node == sg_node) { + return adt::Ok{}; + } + if (match_ctx->HasBigGraphNode(sg_node)) { + ADT_LET_CONST_REF(current_updated, + UpdateByConnections(match_ctx, sg_node)); + updated = updated || current_updated; + } else { + ADT_RETURN_IF_ERR(TopoMatchCtxInitNode(match_ctx, sg_node)); + updated = true; + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(ptn_bfs_walker(anchor_node, Update)); + return updated; + } + + adt::Result UpdateByConnections( + TopoMatchCtxImpl* ctx, const sg_node_t& sg_node) { + ADT_LET_CONST_REF(bg_nodes_ptr, ctx->GetBigGraphNodes(sg_node)); + const size_t old_num_bg_nodes = bg_nodes_ptr->size(); + auto Update = [&](const sg_node_t& nearby_node, + tIsUpstream is_upstream) -> adt::Result { + ADT_LET_CONST_REF(bg_nodes, + GetMatchedBigGraphNodesFromConnected( + *ctx, sg_node, nearby_node, is_upstream)); + ADT_CHECK(!bg_nodes.empty()) << adt::errors::RuntimeError{ + std::string() + "small_graph_node: " + + graph::NodeDescriptor{}.DebugId(sg_node) + + ", old_big_graph_nodes: " + GetNodesDebugIds(bg_nodes_ptr) + + ", nearby_node: " + + graph::NodeDescriptor{}.DebugId(nearby_node) + + ", is_nearby_node_from_upstream: " + + std::to_string(is_upstream.value())}; + ADT_RETURN_IF_ERR(ctx->UpdateBigGraphNodes(sg_node, bg_nodes)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(ForEachInitedUpstream(*ctx, sg_node, Update)); + ADT_RETURN_IF_ERR(ForEachInitedDownstream(*ctx, sg_node, Update)); + return old_num_bg_nodes != bg_nodes_ptr->size(); + } + + std::string GetNodesDebugIds(const std::list* nodes) const { + std::ostringstream ss; + int i = 0; + for (const auto& node : *nodes) { + if (i++ > 0) { + ss << " "; + } + ss << graph::NodeDescriptor{}.DebugId(node); + } + return ss.str(); + } + + adt::Result TopoMatchCtxInitNode( + TopoMatchCtxImpl* ctx, const sg_node_t& sg_node) { + ADT_CHECK(!ctx->HasBigGraphNode(sg_node)); + bool inited = false; + auto 
InitOrUpdate = + [&](const sg_node_t& node, + tIsUpstream is_upstream) -> adt::Result { + if (!inited) { + ADT_LET_CONST_REF(bg_nodes, + GetInitialMatchedBigGraphNodesFromConnected( + *ctx, sg_node, node, is_upstream)); + ADT_RETURN_IF_ERR(ctx->InitBigGraphNodes(sg_node, bg_nodes)); + inited = (bg_nodes.size() > 0); + } else { + ADT_LET_CONST_REF(bg_nodes, + GetMatchedBigGraphNodesFromConnected( + *ctx, sg_node, node, is_upstream)); + ADT_RETURN_IF_ERR(ctx->UpdateBigGraphNodes(sg_node, bg_nodes)); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(ForEachInitedUpstream(*ctx, sg_node, InitOrUpdate)); + ADT_RETURN_IF_ERR(ForEachInitedDownstream(*ctx, sg_node, InitOrUpdate)); + ADT_CHECK(inited) << adt::errors::MismatchError{ + "sg_node not successfully inited."}; + return adt::Ok{}; + } + + adt::Result> GetInitialMatchedBigGraphNodesFromConnected( + const TopoMatchCtxImpl& ctx, + const sg_node_t& sg_node, + const sg_node_t& from_node, + tIsUpstream is_from_node_upstream) { + std::list bg_nodes; + const auto& DoEachMatched = + [&](const bg_node_t& bg_node) -> adt::Result { + bg_nodes.emplace_back(bg_node); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitMatchedBigGraphNodesFromConnected( + ctx, sg_node, from_node, is_from_node_upstream, DoEachMatched)); + return bg_nodes; + } + + adt::Result> + GetMatchedBigGraphNodesFromConnected( + const TopoMatchCtxImpl& ctx, + const sg_node_t& sg_node, + const sg_node_t& from_node, + tIsUpstream is_from_node_upstream) { + std::unordered_set bg_nodes; + const auto& DoEachMatched = + [&](const bg_node_t& bg_node) -> adt::Result { + bg_nodes.insert(bg_node); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitMatchedBigGraphNodesFromConnected( + ctx, sg_node, from_node, is_from_node_upstream, DoEachMatched)); + return bg_nodes; + } + + template + adt::Result VisitMatchedBigGraphNodesFromConnected( + const TopoMatchCtxImpl& ctx, + const sg_node_t& sg_node, + const sg_node_t& from_node, + tIsUpstream is_from_node_upstream, + const 
DoEachT& DoEach) { + ADT_LET_CONST_REF(sg_node_topo_cstr, + sg_descriptor_.GetSmallGraphNodeTopoCstr(sg_node)); + const auto& VisitBigGraphNode = + [&](const bg_node_t& bg_node) -> adt::Result { + ADT_LET_CONST_REF(topo_matched, + bg_descriptor_.TopoSatisfy(bg_node, sg_node_topo_cstr)); + bool matched = topo_matched; + if (matched) { + ap::graph::NodeDescriptor node_descriptor{}; + ADT_LET_CONST_REF( + attrs_matched, + node_descriptor.AttrsSatisfyIfBothAreOpsOrValues(bg_node, sg_node)); + matched = attrs_matched; + } + if (!matched) { + return adt::Ok{}; + } + const auto& opt_matched_sg_node = ctx.GetMatchedSmallGraphNode(bg_node); + if (!opt_matched_sg_node.has_value() || + opt_matched_sg_node.value() == sg_node) { + return DoEach(bg_node); + } + return adt::Ok{}; + }; + ADT_LET_CONST_REF(from_bg_nodes_ptr, ctx.GetBigGraphNodes(from_node)); + for (const bg_node_t& from_bg_node : *from_bg_nodes_ptr) { + if (is_from_node_upstream.value()) { + ADT_RETURN_IF_ERR(bg_descriptor_.VisitDownstreamNodes( + from_bg_node, VisitBigGraphNode)); + } else { + ADT_RETURN_IF_ERR( + bg_descriptor_.VisitUpstreamNodes(from_bg_node, VisitBigGraphNode)); + } + } + return adt::Ok{}; + } + + template + adt::Result ForEachInitedUpstream( + const TopoMatchCtxImpl& ctx, + const sg_node_t& sg_node, + const DoEachT& DoEach) { + auto Visit = [&](const sg_node_t& src) -> adt::Result { + if (ctx.HasBigGraphNode(src)) { + return DoEach(src, tIsUpstream{true}); + } + return adt::Ok{}; + }; + return sg_descriptor_.VisitUpstreamNodes(sg_node, Visit); + } + + template + adt::Result ForEachInitedDownstream( + const TopoMatchCtxImpl& ctx, + const sg_node_t& sg_node, + const DoEachT& DoEach) { + auto Visit = [&](const sg_node_t& dst) -> adt::Result { + if (ctx.HasBigGraphNode(dst)) { + return DoEach(dst, tIsUpstream{false}); + } + return adt::Ok{}; + }; + return sg_descriptor_.VisitDownstreamNodes(sg_node, Visit); + } + + GraphDescriptor bg_descriptor_; + GraphDescriptor sg_descriptor_; +}; + +} // 
namespace ap::ir_match diff --git a/paddle/ap/include/kernel_dispatch/ap_unary_kernel.h b/paddle/ap/include/kernel_dispatch/ap_unary_kernel.h new file mode 100644 index 00000000000000..cfa05c436c0b2d --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/ap_unary_kernel.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/kernel_dispatch/device_ctx.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +adt::Result ApUnaryKernel( + const DeviceCtx& device_ctx, + const std::vector& xs, + int num_outputs, + const std::string& kernel_define_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs); + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/arg_value.h b/paddle/ap/include/kernel_dispatch/arg_value.h new file mode 100644 index 00000000000000..98c304553f5959 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/arg_value.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/kernel_dispatch/typed_buffer.h" +#include "paddle/ap/include/rt_module/arg_value.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +using code_module::ArgType; + +using rt_module::ArgValue; + +using rt_module::CastToArgValue; + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/builtin_frame_util.h b/paddle/ap/include/kernel_dispatch/builtin_frame_util.h new file mode 100644 index 00000000000000..e3f3915b11ac35 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/builtin_frame_util.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/kernel_dispatch/const_tensor_method_class.h" +#include "paddle/ap/include/kernel_dispatch/dispatch_ctx_method_class.h" +#include "paddle/ap/include/kernel_dispatch/mutable_tensor_method_class.h" + +namespace ap::kernel_dispatch { + +template +void VisitEachBuiltinFrameAttr(const DoEachT& DoEach) { + // Do Nothing. +} + +template +axpr::AttrMap MakeBuiltinFrameAttrMap() { + axpr::AttrMap attr_map; + auto Insert = [&](const std::string& k, const ValueT& v) { + attr_map->Set(k, v); + }; + axpr::VisitEachBuiltinFrameAttr(Insert); + VisitEachBuiltinFrameAttr(Insert); + return attr_map; +} + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/const_tensor.h b/paddle/ap/include/kernel_dispatch/const_tensor.h new file mode 100644 index 00000000000000..b5c68d33ba2818 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/const_tensor.h @@ -0,0 +1,73 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/kernel_dispatch/typed_buffer.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +using ConstTensorDataImpl = std::variant; +struct ConstTensorData : public ConstTensorDataImpl { + using ConstTensorDataImpl::ConstTensorDataImpl; + ADT_DEFINE_VARIANT_METHODS(ConstTensorDataImpl); + + template + const T* data() const { + return Match( + [](const phi::DenseTensor* tensor) -> const T* { + return reinterpret_cast(tensor->data()); + }, + [](const TypedBuffer& buffer) -> const T* { + return reinterpret_cast(buffer->buffer); + }); + } + const void* data() const { + return Match( + [](const phi::DenseTensor* tensor) -> const void* { + return tensor->data(); + }, + [](const TypedBuffer& buffer) -> const void* { + return buffer->buffer; + }); + } + phi::DataType dtype() const { + return Match([](const phi::DenseTensor* tensor) { return tensor->dtype(); }, + [](const TypedBuffer& buffer) { return buffer->dtype; }); + } +}; + +template +struct ConstTensorImpl { + ConstTensorData tensor_data; + adt::List dims; + + bool operator==(const ConstTensorImpl& other) const { + return other.tensor_data == this->tensor_data && other.dims == this->dims; + } +}; + +template +ADT_DEFINE_RC(ConstTensor, ConstTensorImpl); + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/const_tensor_method_class.h b/paddle/ap/include/kernel_dispatch/const_tensor_method_class.h new file mode 100644 index 00000000000000..3d8f102521892e --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/const_tensor_method_class.h @@ -0,0 +1,110 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/kernel_dispatch/const_tensor.h" + +namespace ap::kernel_dispatch { + +using ap::axpr::BuiltinBinaryFunc; +using ap::axpr::BuiltinFuncType; +using ap::axpr::BuiltinUnaryFunc; +using ap::axpr::CppDataType; +using ap::axpr::CppPointerType; +using ap::axpr::DataType; +using ap::axpr::Method; +using ap::axpr::MethodClass; +using ap::axpr::PointerType; +using ap::axpr::PointerValue; + +namespace detail { + +template +Result ConstTensorShapeGetAttr(const ConstTensor& tensor, + const std::string&) { + return tensor->dims; +} + +template +const T* GetConstTensorDataPtr(const ap::axpr::CppDataType&, + const ConstTensorData& tensor) { + return tensor.template data(); +} + +template +Result ConstTensorDataGetAttr(const ConstTensor& tensor, + const std::string&) { + phi::DataType dtype = tensor->tensor_data.dtype(); + const auto& data_type = ap::axpr::GetDataTypeFromPhiDataType(dtype); + ADT_RETURN_IF_ERR(data_type); + return data_type.GetOkValue().Match( + [&](const adt::Undefined&) -> Result { + return TypeError{"dtype is invalid."}; + }, + [&](const auto& impl) -> Result { + return PointerValue{GetConstTensorDataPtr(impl, tensor->tensor_data)}; + }); +} + +template +using ConstTensorGetAttrT = Result (*)(const 
ConstTensor& tensor, + const std::string&); + +template +Result TensorGetAttr(const ConstTensor& tensor, + const std::string& name) { + static const std::unordered_map> map{ + {"shape", &ConstTensorShapeGetAttr}, + {"data_ptr", &ConstTensorDataGetAttr}, + }; + const auto& iter = map.find(name); + if (iter == map.end()) { + return AttributeError{std::string("'Tensor' has no attribute '") + name + + "'"}; + } + return iter->second(tensor, name); +} + +} // namespace detail + +template +struct ConstTensorMethodClass { + using Self = ConstTensor; + + static adt::Result GetAttr(const ValueT& obj_val, + const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(obj, axpr::Get>(obj_val)); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + return detail::TensorGetAttr(obj, attr_name); + } +}; + +template +axpr::TypeImpl> GetConstTensorClass() { + using ImplMethods = ConstTensorMethodClass; + static auto cls( + axpr::MakeBuiltinClass("ConstTensor", [&](const auto& DoEach) { + DoEach("__getattr__", &ImplMethods::GetAttr); + })); + using Self = typename ImplMethods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/device_ctx.h b/paddle/ap/include/kernel_dispatch/device_ctx.h new file mode 100644 index 00000000000000..3d937072fa19ce --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/device_ctx.h @@ -0,0 +1,36 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/pointer_value.h" + +namespace ap::kernel_dispatch { + +class DeviceCtxImpl { + public: + virtual ~DeviceCtxImpl() {} + + virtual adt::Result GetStreamAddrAsVoidPtr() = 0; + + bool operator==(const DeviceCtxImpl& other) const { return this == &other; } + + protected: + DeviceCtxImpl() {} +}; + +ADT_DEFINE_RC(DeviceCtx, DeviceCtxImpl); + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/device_ctx_method_class.h b/paddle/ap/include/kernel_dispatch/device_ctx_method_class.h new file mode 100644 index 00000000000000..feb6258a30e857 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/device_ctx_method_class.h @@ -0,0 +1,25 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/kernel_dispatch/device_ctx.h" + +namespace ap::kernel_dispatch { + +axpr::TypeImpl> GetDeviceCtxClass(); + +} diff --git a/paddle/ap/include/kernel_dispatch/dispatch_ctx.h b/paddle/ap/include/kernel_dispatch/dispatch_ctx.h new file mode 100644 index 00000000000000..3b49dbe5c39505 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/dispatch_ctx.h @@ -0,0 +1,46 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/code_module/arg_type.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/kernel_dispatch/dispatch_raw_ctx.h" +#include "paddle/ap/include/kernel_dispatch/typed_buffer.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +template +struct DispatchCtxImpl { + DispatchRawCtx raw_ctx; + + axpr::AttrMap kernel_dispatch_const_data; + + bool operator==(const DispatchCtxImpl& other) const { return &other == this; } +}; + +template +ADT_DEFINE_RC(DispatchCtx, DispatchCtxImpl); + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/dispatch_ctx_method_class.h b/paddle/ap/include/kernel_dispatch/dispatch_ctx_method_class.h new file mode 100644 index 00000000000000..349dad45dfce9d --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/dispatch_ctx_method_class.h @@ -0,0 +1,246 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/axpr/value_method_class.h" +#include "paddle/ap/include/kernel_dispatch/device_ctx_method_class.h" +#include "paddle/ap/include/kernel_dispatch/dispatch_ctx.h" +#include "paddle/ap/include/rt_module/function_method_class.h" + +namespace ap::kernel_dispatch { + +using ap::axpr::BuiltinBinaryFunc; +using ap::axpr::BuiltinFuncType; +using ap::axpr::BuiltinUnaryFunc; +using ap::axpr::CppDataType; +using ap::axpr::CppPointerType; +using ap::axpr::DataType; +using ap::axpr::DataValue; +using ap::axpr::MethodClass; +using ap::axpr::PointerType; +using ap::axpr::PointerValue; + +namespace detail { + +template +Result DispatchCtxGetInputs(const DispatchCtx& ctx, + const std::string& attr_name) { + return ctx->raw_ctx->inputs; +} + +template +Result DispatchCtxGetOutputs(const DispatchCtx& ctx, + const std::string& attr_name) { + return ctx->raw_ctx->outputs; +} + +template +Result DispatchCtxGetDeviceCtx(const DispatchCtx& ctx, + const std::string& attr_name) { + return GetDeviceCtxClass().New(ctx->raw_ctx->device_ctx); +} + +template +Result> GetKernelArgs(const Val& args) { + const Result>& arg_list = + args.template TryGet>(); + ADT_RETURN_IF_ERR(arg_list); + adt::List ret; + ret->reserve(arg_list.GetOkValue()->size()); + for (const auto& arg : *arg_list.GetOkValue()) { + const Result& arg_value = CastToArgValue(arg); + ADT_RETURN_IF_ERR(arg_value); + ret->emplace_back(arg_value.GetOkValue()); + } + return ret; +} + +template +Result LaunchCuda(const Val& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 4) << TypeError{ + std::string() + + "DispatchCtx.launch_cuda take 6 arguments (including self) but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(ctx, axpr::Get>(self_val)); + 
ADT_LET_CONST_REF(func_name, args.at(0).template TryGet()); + ADT_LET_CONST_REF(num_blocks, args.at(1).template TryGet()); + ADT_LET_CONST_REF(num_threads, args.at(2).template TryGet()); + ADT_LET_CONST_REF(kernel_args, GetKernelArgs(args.at(3))); + ADT_RETURN_IF_ERR(ctx->raw_ctx->LaunchCudaKernel( + func_name, num_blocks, num_threads, kernel_args)); + return adt::Nothing{}; +} + +template BuiltinFunc> +Result MakeDispatchCtxMethod(const DispatchCtx& ctx, + const std::string&) { + return ap::axpr::Method{ctx, BuiltinFuncType{BuiltinFunc}}; +} + +template +Result DispatchCtxType(const DispatchCtx& ctx, const std::string&) { + return ap::axpr::TypeImpl{}; +} + +template +using KernelCtxGettAttrT = Result (*)(const DispatchCtx& ctx, + const std::string&); + +template +Result DispatchCtxGetAttr(const DispatchCtx& ctx, + const std::string& name) { + static const std::unordered_map> map{ + {ap::axpr::TypeImpl{}.Name(), + &DispatchCtxType}, + {"inputs", &DispatchCtxGetInputs}, + {"outputs", &DispatchCtxGetOutputs}, + {"device_ctx", &DispatchCtxGetDeviceCtx}, + }; + const auto& iter = map.find(name); + if (iter == map.end()) { + return AttributeError{std::string("'DispatchCtx' has no attribute '") + + name + "'"}; + } + return iter->second(ctx, name); +} + +} // namespace detail + +template +struct DispatchCtxMethodClass { + using This = DispatchCtxMethodClass; + using Self = DispatchCtx; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, axpr::Get(self_val)); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "kernel_dispatch_const_data") { + return 
self->kernel_dispatch_const_data; + } + return detail::DispatchCtxGetAttr(self, attr_name); + } + + static adt::Result StaticGetInputIndexByName( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, axpr::Get(self_val)); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "'DispatchCtx.get_input_index_by_name' takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(tensor_name, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of 'DispatchCtx.get_input_index_by_name' should " + "be str (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + return This{}.GetInputIndexByName(self, tensor_name); + } + + adt::Result GetInputIndexByName(const Self& self, + const std::string& tensor_name) { + const auto& data = self->kernel_dispatch_const_data; + ADT_LET_CONST_REF( + name2idx, + data->template TryGet>( + "__builtin_ap_kernel_input_name_to_index")); + ADT_LET_CONST_REF(index, name2idx->template TryGet(tensor_name)); + return index; + } + + static adt::Result StaticGetOutputIndexByName( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, axpr::Get(self_val)); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "'DispatchCtx.get_output_index_by_name' takes 1 argument but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(tensor_name, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of 'DispatchCtx.get_output_index_by_name' " + "should be str (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + return This{}.GetOutputIndexByName(self, tensor_name); + } + + static adt::Result StaticGetSoFunction( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, axpr::Get(self_val)); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "'DispatchCtx.get_so_function()' takes 1 argument but " 
+ + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(function_name, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of 'DispatchCtx.get_so_function()' " + "should be str (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_LET_CONST_REF( + rt_module, + self->raw_ctx->rt_module + .template TryGet>()); + ADT_LET_CONST_REF(function, rt_module->Get(function_name)) + << adt::errors::TypeError{ + std::string() + + "DispatchCtx.get_so_function() failed. so function '" + + function_name + "' not found"}; + return rt_module::GetSoFunctionClass().New(function); + } + + adt::Result GetOutputIndexByName(const Self& self, + const std::string& tensor_name) { + const auto& data = self->kernel_dispatch_const_data; + ADT_LET_CONST_REF( + name2idx, + data->template TryGet>( + "__builtin_ap_kernel_output_name_to_index")); + ADT_LET_CONST_REF(index, name2idx->template TryGet(tensor_name)); + return index; + } +}; + +template +axpr::TypeImpl> GetDispatchCtxClass() { + using Methods = DispatchCtxMethodClass; + static auto cls( + axpr::MakeBuiltinClass("DispatchCtx", [&](const auto& Yield) { + Yield("__str__", &Methods::ToString); + Yield("__getattr__", &Methods::GetAttr); + Yield("get_input_index_by_name", &Methods::StaticGetInputIndexByName); + Yield("get_output_index_by_name", &Methods::StaticGetOutputIndexByName); + Yield("get_so_function", &Methods::StaticGetSoFunction); + })); + using Self = typename Methods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/dispatch_raw_ctx.h b/paddle/ap/include/kernel_dispatch/dispatch_raw_ctx.h new file mode 100644 index 00000000000000..3b177bd5e7270b --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/dispatch_raw_ctx.h @@ -0,0 +1,73 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/kernel_dispatch/arg_value.h" +#include "paddle/ap/include/kernel_dispatch/device_ctx.h" +#include "paddle/ap/include/kernel_dispatch/typed_buffer.h" +#include "paddle/ap/include/rt_module/module.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +using RtModuleImpl = std::variant>; + +struct RtModule : public RtModuleImpl { + using RtModuleImpl::RtModuleImpl; + ADT_DEFINE_VARIANT_METHODS(RtModuleImpl); +}; + +template +struct DispatchRawCtxImpl { + DeviceCtx device_ctx; + adt::List inputs; + adt::List outputs; + RtModule rt_module; + + bool operator==(const DispatchRawCtxImpl& other) const { + return &other == this; + } + + Result LaunchCudaKernel( + const std::string& func_name, + int64_t num_blocks, + int64_t num_threads, + const adt::List& kernel_args) const; +}; + +template +ADT_DEFINE_RC(DispatchRawCtx, DispatchRawCtxImpl); + +} // namespace ap::kernel_dispatch + +namespace ap::axpr { + +template +struct TypeImpl> + : public std::monostate { + using value_type = ap::kernel_dispatch::DispatchRawCtx; + + const char* Name() const { return "DispatchRawCtx"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/kernel_dispatch/mutable_tensor.h 
b/paddle/ap/include/kernel_dispatch/mutable_tensor.h new file mode 100644 index 00000000000000..04ab28e24e586a --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/mutable_tensor.h @@ -0,0 +1,68 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/kernel_dispatch/typed_buffer.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +using MutableTensorDataImpl = std::variant; +struct MutableTensorData : public MutableTensorDataImpl { + using MutableTensorDataImpl::MutableTensorDataImpl; + ADT_DEFINE_VARIANT_METHODS(MutableTensorDataImpl); + + template + T* data() const { + return Match( + [](phi::DenseTensor* tensor) -> T* { + return reinterpret_cast(tensor->data()); + }, + [](const TypedBuffer& buffer) -> T* { + return reinterpret_cast(buffer->buffer); + }); + } + void* data() const { + return Match( + [](phi::DenseTensor* tensor) -> void* { return tensor->data(); }, + [](const TypedBuffer& buffer) -> void* { return buffer->buffer; }); + } + phi::DataType dtype() const { + return Match([](phi::DenseTensor* tensor) { return tensor->dtype(); }, + [](const TypedBuffer& buffer) { return buffer->dtype; }); + } +}; + +template +struct 
MutableTensorImpl { + MutableTensorData tensor_data; + adt::List dims; + + bool operator==(const MutableTensorImpl& other) const { + return other.tensor_data == this->tensor_data && other.dims == this->dims; + } +}; +template +ADT_DEFINE_RC(MutableTensor, MutableTensorImpl); + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/mutable_tensor_method_class.h b/paddle/ap/include/kernel_dispatch/mutable_tensor_method_class.h new file mode 100644 index 00000000000000..b70bf6f8a0daa5 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/mutable_tensor_method_class.h @@ -0,0 +1,110 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/kernel_dispatch/mutable_tensor.h" + +namespace ap::kernel_dispatch { + +using ap::axpr::BuiltinBinaryFunc; +using ap::axpr::BuiltinFuncType; +using ap::axpr::BuiltinUnaryFunc; +using ap::axpr::CppDataType; +using ap::axpr::CppPointerType; +using ap::axpr::DataType; +using ap::axpr::DataValue; +using ap::axpr::Method; +using ap::axpr::MethodClass; +using ap::axpr::PointerType; +using ap::axpr::PointerValue; + +namespace detail { + +template +Result MutableTensorShapeGetAttr(const MutableTensor& tensor, + const std::string&) { + return tensor->dims; +} + +template +T* GetMutableTensorDataPtr(const ap::axpr::CppDataType&, + const MutableTensorData& tensor) { + return tensor.template data(); +} + +template +Result MutableTensorDataGetAttr(const MutableTensor& tensor, + const std::string&) { + phi::DataType dtype = tensor->tensor_data.dtype(); + const auto& data_type = ap::axpr::GetDataTypeFromPhiDataType(dtype); + ADT_RETURN_IF_ERR(data_type); + return data_type.GetOkValue().Match( + [&](const adt::Undefined&) -> Result { + return TypeError{"dtype is invalid."}; + }, + [&](const auto& impl) -> Result { + return PointerValue{GetMutableTensorDataPtr(impl, tensor->tensor_data)}; + }); +} + +template +using MutableTensorGetAttrT = Result (*)(const MutableTensor& tensor, + const std::string&); + +template +Result TensorGetAttr(const MutableTensor& tensor, + const std::string& name) { + static const std::unordered_map> map{ + {"shape", &MutableTensorShapeGetAttr}, + {"data_ptr", &MutableTensorDataGetAttr}, + }; + const auto& iter = map.find(name); + if (iter == map.end()) { + return AttributeError{std::string("'Tensor' has no attribute '") + name + + "'"}; + } + return iter->second(tensor, name); +} + +} // namespace detail + +template +struct MutableTensorMethodClass { + 
using Self = MutableTensor; + + static adt::Result GetAttr(const ValueT& obj_val, + const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(obj, axpr::Get>(obj_val)); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + return detail::TensorGetAttr(obj, attr_name); + } +}; + +template +axpr::TypeImpl> GetMutableTensorClass() { + using Methods = MutableTensorMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "MutableTensor", + [&](const auto& Yield) { Yield("__getattr__", &Methods::GetAttr); })); + using Self = typename Methods::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/typed_buffer.h b/paddle/ap/include/kernel_dispatch/typed_buffer.h new file mode 100644 index 00000000000000..dfe67d0426ed1b --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/typed_buffer.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/code_module/adt.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +class DenseTensor; + +} + +namespace ap::kernel_dispatch { + +struct TypedBufferImpl { + void* buffer; + phi::DataType dtype; + size_t size; + + bool operator==(const TypedBufferImpl& other) const { + return other.buffer == this->buffer && other.dtype == this->dtype && + other.size == this->size; + } +}; +ADT_DEFINE_RC(TypedBuffer, TypedBufferImpl); + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/kernel_dispatch/value.h b/paddle/ap/include/kernel_dispatch/value.h new file mode 100644 index 00000000000000..88126a24d3d683 --- /dev/null +++ b/paddle/ap/include/kernel_dispatch/value.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/code_module/data_type.h" +#include "paddle/ap/include/kernel_dispatch/arg_value.h" +#include "paddle/ap/include/kernel_dispatch/const_tensor.h" +#include "paddle/ap/include/kernel_dispatch/dispatch_ctx.h" +#include "paddle/ap/include/kernel_dispatch/mutable_tensor.h" +#include "paddle/ap/include/kernel_dispatch/typed_buffer.h" + +namespace ap::kernel_dispatch { + +using axpr::Value; + +using Val = Value; + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/include/memory/circlable_ref.h b/paddle/ap/include/memory/circlable_ref.h new file mode 100644 index 00000000000000..fb96d92cf80543 --- /dev/null +++ b/paddle/ap/include/memory/circlable_ref.h @@ -0,0 +1,63 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/memory/circlable_ref_impl.h" +#include "paddle/ap/include/memory/circlable_ref_list.h" + +namespace ap::memory { + +template +class CirclableRef { + public: + CirclableRef(const CirclableRef&) = default; + CirclableRef(CirclableRef&&) = default; + explicit CirclableRef(const std::shared_ptr>& impl) + : impl_(impl) {} + CirclableRef& operator=(const CirclableRef&) = default; + CirclableRef& operator=(CirclableRef&&) = default; + + explicit operator bool() const { return impl_->operator bool(); } + + adt::Result Get() const { return impl_->Get(); } + + adt::Result Mut() const { return impl_->Mut(); } + + adt::Result> shared_ptr() const { + ADT_CHECK(this->operator bool()); + return impl_->shared_ptr(); + } + + bool operator==(const CirclableRef& other) const { + return this->impl_ == other.impl_; + } + + static Derived Make(const std::shared_ptr& ref_list, + const std::shared_ptr& obj) { + auto impl = std::make_shared>(obj); + auto iter = ref_list->AddWeakRef(impl); + const auto& ok = impl->InitWeakRefIterAndList(iter, ref_list); + (void)ok; + return Derived(impl); + } + + protected: + T* raw_ptr() const { return impl_->shared_ptr().get(); } + + std::shared_ptr> impl_; +}; + +} // namespace ap::memory diff --git a/paddle/ap/include/memory/circlable_ref_impl.h b/paddle/ap/include/memory/circlable_ref_impl.h new file mode 100644 index 00000000000000..3a43dffd773559 --- /dev/null +++ b/paddle/ap/include/memory/circlable_ref_impl.h @@ -0,0 +1,62 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/memory/circlable_ref_impl_base.h" +#include "paddle/ap/include/memory/circlable_ref_list.h" + +namespace ap::memory { + +template +class CirclableRefImpl : public CirclableRefImplBase { + public: + explicit CirclableRefImpl(const std::shared_ptr& data) : data_(data) {} + ~CirclableRefImpl() override { EraseIterFromList(); } + + void ClearRef() override { data_.reset(); } + + void EraseIterFromList() { + if (circlable_ref_list_weak_ptr_.has_value() && + weak_ref_iter_.has_value()) { + if (auto ptr = circlable_ref_list_weak_ptr_.value().lock()) { + const auto& ret = ptr->EraseWeakRef(this); + (void)ret; + } + } + } + + explicit operator bool() const { return static_cast(data_); } + + adt::Result Get() const { + const auto* ptr = data_.get(); + ADT_CHECK(ptr != nullptr) << adt::errors::TypeError{ + "ptr is deleted. please check CirclableRefList is alive."}; + return ptr; + } + + adt::Result Mut() { + auto* ptr = data_.get(); + ADT_CHECK(ptr != nullptr) << adt::errors::TypeError{ + "ptr is deleted. please check CirclableRefList is alive."}; + return ptr; + } + + const std::shared_ptr& shared_ptr() const { return data_; } + + private: + std::shared_ptr data_; +}; + +} // namespace ap::memory diff --git a/paddle/ap/include/memory/circlable_ref_impl_base.h b/paddle/ap/include/memory/circlable_ref_impl_base.h new file mode 100644 index 00000000000000..2a002d2fae3bce --- /dev/null +++ b/paddle/ap/include/memory/circlable_ref_impl_base.h @@ -0,0 +1,67 @@ +// Copyright (c) 2024 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/memory/circlable_ref_list_base.h" + +namespace ap::memory { + +class CirclableRefImplBase + : public std::enable_shared_from_this { + public: + CirclableRefImplBase(const CirclableRefImplBase&) = delete; + CirclableRefImplBase(CirclableRefImplBase&&) = delete; + virtual ~CirclableRefImplBase() {} + + virtual void ClearRef() = 0; + + using WeakRefList = std::list>; + using CirclableRefListPtr = std::weak_ptr; + + const std::optional& weak_ref_iter() const { + return weak_ref_iter_; + } + + const std::optional& circlable_ref_list_ptr() const { + return circlable_ref_list_weak_ptr_; + } + + adt::Result InitWeakRefIterAndList(WeakRefList::iterator iter, + const CirclableRefListPtr& list) { + ADT_CHECK(!weak_ref_iter_.has_value()); + ADT_CHECK(!circlable_ref_list_weak_ptr_.has_value()); + weak_ref_iter_ = iter; + circlable_ref_list_weak_ptr_ = list; + return adt::Ok{}; + } + + void ClearIterAndRef() { + weak_ref_iter_ = std::nullopt; + circlable_ref_list_weak_ptr_ = std::nullopt; + ClearRef(); + } + + protected: + CirclableRefImplBase() = default; + + std::optional weak_ref_iter_; + std::optional circlable_ref_list_weak_ptr_; +}; + +} // namespace ap::memory diff --git a/paddle/ap/include/memory/circlable_ref_list.h b/paddle/ap/include/memory/circlable_ref_list.h new file mode 
100644 index 00000000000000..14f05d551aa8da --- /dev/null +++ b/paddle/ap/include/memory/circlable_ref_list.h @@ -0,0 +1,76 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/memory/circlable_ref_impl_base.h" +#include "paddle/ap/include/memory/circlable_ref_list_base.h" + +namespace ap::memory { + +class CirclableRefList : public CirclableRefListBase { + public: + CirclableRefList(const CirclableRefList&) = delete; + CirclableRefList(CirclableRefList&&) = delete; + ~CirclableRefList() override { + const auto& ok = EraseAllWeakRef(); + (void)ok; + } + + CirclableRefList() {} + + const WeakRefList& weak_refs() const { return weak_refs_; } + + WeakRefList::iterator AddWeakRef( + const std::shared_ptr& ref) override { + std::weak_ptr weak_ref{ref}; + WeakRefList::iterator weak_iter = + weak_refs_.insert(weak_refs_.end(), weak_ref); + return weak_iter; + } + + adt::Result EraseWeakRef( + CirclableRefImplBase* ref) override { + ADT_CHECK(ref->weak_ref_iter().has_value()); + ADT_CHECK(ref->circlable_ref_list_ptr().has_value()); + const auto& weak_ptr = ref->circlable_ref_list_ptr().value(); + if (!weak_ptr.expired()) { + ADT_CHECK(weak_ptr.lock().get() == this); + } else { + // called by my own destructor. 
+ ADT_CHECK(weak_ptr.lock().get() == nullptr); + } + auto iter = weak_refs_.erase(ref->weak_ref_iter().value()); + ref->ClearIterAndRef(); + return iter; + } + + private: + adt::Result EraseAllWeakRef() { + for (auto iter = weak_refs_.begin(); iter != weak_refs_.end();) { + if (auto ref = iter->lock()) { + ADT_CHECK(ref->weak_ref_iter().has_value()); + ADT_CHECK(iter == ref->weak_ref_iter().value()); + ADT_LET_CONST_REF(next_iter, this->EraseWeakRef(ref.get())); + iter = next_iter; + } else { + ++iter; + } + } + return adt::Ok{}; + } + WeakRefList weak_refs_; +}; + +} // namespace ap::memory diff --git a/paddle/ap/include/memory/circlable_ref_list_base.h b/paddle/ap/include/memory/circlable_ref_list_base.h new file mode 100644 index 00000000000000..b8d15efcbe1c75 --- /dev/null +++ b/paddle/ap/include/memory/circlable_ref_list_base.h @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::memory { + +class CirclableRefImplBase; + +class CirclableRefListBase { + public: + CirclableRefListBase(const CirclableRefListBase&) = delete; + CirclableRefListBase(CirclableRefListBase&&) = delete; + virtual ~CirclableRefListBase() = default; + + using WeakRefList = std::list>; + + virtual WeakRefList::iterator AddWeakRef( + const std::shared_ptr& ref) = 0; + + virtual adt::Result EraseWeakRef( + CirclableRefImplBase* ref) = 0; + + protected: + CirclableRefListBase() = default; +}; + +} // namespace ap::memory diff --git a/paddle/ap/include/memory/guard.h b/paddle/ap/include/memory/guard.h new file mode 100644 index 00000000000000..f343b1be41a416 --- /dev/null +++ b/paddle/ap/include/memory/guard.h @@ -0,0 +1,35 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/memory/circlable_ref_list.h" + +namespace ap::memory { + +class Guard final { + public: + Guard(const Guard&) = delete; + Guard(Guard&&) = delete; + Guard() : circlable_ref_list_(std::make_shared()) {} + + const std::shared_ptr& circlable_ref_list() const { + return circlable_ref_list_; + } + + private: + std::shared_ptr circlable_ref_list_; +}; + +} // namespace ap::memory diff --git a/paddle/ap/include/paddle/builtin_frame_util.h b/paddle/ap/include/paddle/builtin_frame_util.h new file mode 100644 index 00000000000000..e3d9051727333e --- /dev/null +++ b/paddle/ap/include/paddle/builtin_frame_util.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" + +namespace ap::paddle { + +template +ap::axpr::AttrMap MakeBuiltinFrameAttrMap() { + ap::axpr::AttrMap attr_map; + auto Insert = [&](const std::string& k, const ValueT& v) { + attr_map->Set(k, v); + }; + ap::axpr::VisitEachBuiltinFrameAttr(Insert); + return attr_map; +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/const_meta_tensor_ptr.h b/paddle/ap/include/paddle/const_meta_tensor_ptr.h new file mode 100644 index 00000000000000..ec8643b1e223bf --- /dev/null +++ b/paddle/ap/include/paddle/const_meta_tensor_ptr.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/phi/core/meta_tensor.h" + +namespace ap::paddle { + +using ConstMetaTensorPtr = const ::phi::MetaTensor*; + +} + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using std::monostate::monostate; + + const char* Name() const { return "ConstMetaTensorPtr"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h b/paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h new file mode 100644 index 00000000000000..a53e71cfd686c4 --- /dev/null +++ b/paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h @@ -0,0 +1,84 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/paddle/const_meta_tensor_ptr.h" +#include "paddle/ap/include/paddle/ddim_method_class.h" + +namespace ap::paddle { + +struct ConstMetaTensorPtrMethodClass { + using This = ConstMetaTensorPtrMethodClass; + using Self = ConstMetaTensorPtr; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const auto* ptr = self; + ss << "<" << axpr::TypeImpl{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return reinterpret_cast(self); + } + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "dtype") { + return This{}.GetDtype(self); + } + if (attr_name == "dims") { + return This{}.GetDims(self); + } + return adt::errors::AttributeError{ + std::string() + "'ConstMetaTensorPtr' object has no attribute '" + + attr_name + "'."}; + } + + adt::Result GetDims(const Self& self) { + return GetDDimClass().New(self->dims()); + } + + adt::Result GetDtype(const Self& self) { + ADT_LET_CONST_REF(dtype, axpr::GetDataTypeFromPhiDataType(self->dtype())); + return dtype; + } +}; + +inline axpr::TypeImpl> +GetConstMetaTensorPtrClass() { + using Impl = ConstMetaTensorPtrMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "ConstMetaTensorPtr", [&](const auto& Define) { + Define("__str__", &Impl::ToString); + 
Define("__hash__", &Impl::Hash); + Define("__getattr__", &Impl::GetAttr); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h b/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h new file mode 100644 index 00000000000000..cf964b87c80f10 --- /dev/null +++ b/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h @@ -0,0 +1,79 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h" + +namespace ap::paddle { + +struct ConstStdVectorConstMetaTensorPtrPtrMethodClass { + using This = ConstStdVectorConstMetaTensorPtrPtrMethodClass; + using Self = const std::vector*; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return reinterpret_cast(self); + } + + static adt::Result GetItem( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& idx_val = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(idx, idx_val.template CastTo()) + << adt::errors::TypeError{std::string() + + "vector indices must be integers, not " + + axpr::GetTypeName(idx_val)}; + int64_t index = idx; + if (index < 0) { + index += self->size(); + } + if (index >= 0 && index < self->size()) { + return CastItem(self->at(index)); + } + return adt::errors::IndexError{"vector index out of range"}; + } + + static adt::Result CastItem(const ConstMetaTensorPtr& elem) { + return GetConstMetaTensorPtrClass().New(elem); + } +}; + +inline axpr::TypeImpl> +GetConstStdVectorConstMetaTensorPtrPtrClass() { + using Impl = ConstStdVectorConstMetaTensorPtrPtrMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "ConstStdVectorConstMetaTensorPtrPtr", [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getitem__", &Impl::GetItem); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff 
--git a/paddle/ap/include/paddle/ddim.h b/paddle/ap/include/paddle/ddim.h new file mode 100644 index 00000000000000..85c6f8e8b86d3b --- /dev/null +++ b/paddle/ap/include/paddle/ddim.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/common/ddim.h" + +namespace ap::paddle { + +using DDim = ::common::DDim; + +} + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using std::monostate::monostate; + + const char* Name() const { return "DDim"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/paddle/ddim_method_class.h b/paddle/ap/include/paddle/ddim_method_class.h new file mode 100644 index 00000000000000..ed353010a7757b --- /dev/null +++ b/paddle/ap/include/paddle/ddim_method_class.h @@ -0,0 +1,79 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/ddim.h" + +namespace ap::paddle { + +struct DDimMethodClass { + using This = DDimMethodClass; + using Self = DDim; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + ss << "["; + for (int i = 0; i < self.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << self.at(i); + } + ss << "]"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + int64_t hash_value = 0; + for (int i = 0; i < self.size(); ++i) { + hash_value = adt::hash_combine(hash_value, self.at(i)); + } + return hash_value; + } + + static adt::Result GetItem( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& index_val = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(index, index_val.template TryGet()) + << adt::errors::TypeError{std::string() + + "'DDim.__getitem__()' takes integers, not " + + axpr::GetTypeName(index_val) + "."}; + ADT_CHECK(index < self.size()) + << adt::errors::IndexError{"list index out of range"}; + return self.at(index); + } +}; + +inline axpr::TypeImpl> GetDDimClass() { + using Impl = DDimMethodClass; + static auto cls( + axpr::MakeBuiltinClass("DDim", [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getitem__", &Impl::GetItem); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/indexed_ir_graph.h 
b/paddle/ap/include/paddle/indexed_ir_graph.h new file mode 100644 index 00000000000000..9b628fdba7e9ae --- /dev/null +++ b/paddle/ap/include/paddle/indexed_ir_graph.h @@ -0,0 +1,60 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node_arena.h" +#include "paddle/ap/include/paddle/indexed_ir_node.h" + +namespace ap::paddle { + +using IndexedIrNodeArena = graph::NodeArena; +using IndexedIrNodeArenaPtr = std::shared_ptr; + +struct PureElementwiseIndexedIrGraphImpl { + IndexedIrNodeArenaPtr node_arena; + // free values in fusion op block. + std::vector> inputs; + // yield values in fusion op block. + std::vector> yield_op_inputs; + // output values of fusion op. 
+ std::vector outputs; + + std::unordered_map> + pir_value2indexed_ir_value; + + adt::Result> GetIndexedIrValue( + pir::Value value) const { + const auto& iter = this->pir_value2indexed_ir_value.find(value); + ADT_CHECK(iter != this->pir_value2indexed_ir_value.end()); + return iter->second; + } + + bool operator==(const PureElementwiseIndexedIrGraphImpl& other) const { + return this == &other; + } +}; + +ADT_DEFINE_RC(PureElementwiseIndexedIrGraph, PureElementwiseIndexedIrGraphImpl); + +using IndexedIrGraphImpl = std::variant; + +struct IndexedIrGraph : public IndexedIrGraphImpl { + using IndexedIrGraphImpl::IndexedIrGraphImpl; + + ADT_DEFINE_VARIANT_METHODS(IndexedIrGraphImpl); +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/indexed_ir_graph_util.h b/paddle/ap/include/paddle/indexed_ir_graph_util.h new file mode 100644 index 00000000000000..6d9efc91172746 --- /dev/null +++ b/paddle/ap/include/paddle/indexed_ir_graph_util.h @@ -0,0 +1,224 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/ap/include/paddle/indexed_ir_graph.h" +#include "paddle/ap/include/paddle/pir_node.h" +#include "paddle/ap/include/paddle/pir_util.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace ap::paddle { + +adt::Result CreatePureElementwiseIndexedIrGraph( + const PackedIrOp& pir_op, const index_expr::IndexTupleExpr& indexes_expr); + +adt::Result GetPackedIrOpInputsOutputs( + const PackedIrOp& pir_op, + std::vector* inputs, + std::vector* yield_op_inputs, + std::vector* outputs); + +namespace detail { + +struct CreatePureElementwiseIndexedIrGraphHelper { + struct Ctx { + std::unordered_map> value2node; + + bool Has(pir::Value value) const { + return this->value2node.find(value) != this->value2node.end(); + } + + void Insert(pir::Value value, const IndexedIrValue& node) { + this->value2node[value] = node; + } + + adt::Result> Get(pir::Value value) const { + const auto& iter = this->value2node.find(value); + ADT_CHECK(iter != this->value2node.end()); + return iter->second; + } + }; + + adt::Result Create( + const PackedIrOp& pir_op, + const index_expr::IndexTupleExpr& indexes_expr) { + Ctx ctx{}; + ADT_LET_CONST_REF(node_arena_ptr, + CreateNodeArena(&ctx, pir_op, indexes_expr)); + return CreatePureElementwiseIndexedIrGraph(node_arena_ptr, ctx, pir_op); + } + + adt::Result + CreatePureElementwiseIndexedIrGraph(const IndexedIrNodeArenaPtr& node_arena, + const Ctx& ctx, + const PackedIrOp& pir_op) { + std::vector inputs; + std::vector yield_op_inputs; + std::vector outputs; + ADT_RETURN_IF_ERR(GetPackedIrOpInputsOutputs( + pir_op, &inputs, &yield_op_inputs, &outputs)); + ADT_LET_CONST_REF(input_nodes, GetIndexedIrValues(ctx, inputs)); + ADT_LET_CONST_REF(yield_op_input_nodes, + GetIndexedIrValues(ctx, yield_op_inputs)); + return PureElementwiseIndexedIrGraph{ + node_arena, input_nodes, yield_op_input_nodes, outputs, ctx.value2node}; + } + + adt::Result>> 
GetIndexedIrValues( + const Ctx& ctx, const std::vector values) { + std::vector> ret; + ret.reserve(values.size()); + for (const auto& value : values) { + ADT_LET_CONST_REF(ir_value, ctx.Get(value)); + ret.emplace_back(ir_value); + } + return ret; + } + + adt::Result CreateNodeArena( + Ctx* ctx, + const PackedIrOp& pir_op, + const index_expr::IndexTupleExpr& indexes_expr) { + auto node_arena = std::make_shared(); + for (auto& op : *pir_op.fusion_op.block()) { + if (op.isa()) { + continue; + } + const auto& ir_op = InsertOpNode(node_arena, &op); + InsertValueNodes(ctx, node_arena, &op, indexes_expr); + ADT_RETURN_IF_ERR(ConnectOpOperandEdges(ctx, ir_op)); + ADT_RETURN_IF_ERR(ConnectOpResultEdges(ctx, ir_op)); + } + return node_arena; + } + + adt::Result ConnectOpResultEdges( + Ctx* ctx, const IndexedIrOp& ir_op) { + auto* op = ir_op->op; + for (int i = 0; i < op->num_results(); ++i) { + ADT_LET_CONST_REF(ir_value, ctx->Get(op->result(i))); + ADT_RETURN_IF_ERR( + ir_op->node.ConnectTo(ir_value->node, + graph::IndexedTag{}, + graph::UnindexedTag{})); + } + return adt::Ok{}; + } + + adt::Result ConnectOpOperandEdges( + Ctx* ctx, const IndexedIrOp& ir_op) { + auto* op = ir_op->op; + for (int i = 0; i < op->num_operands(); ++i) { + ADT_LET_CONST_REF(ir_value, ctx->Get(op->operand_source(i))); + ADT_RETURN_IF_ERR( + ir_value->node.ConnectTo(ir_op->node, + graph::UnindexedTag{}, + graph::IndexedTag{})); + } + return adt::Ok{}; + } + + void InsertValueNodes(Ctx* ctx, + const IndexedIrNodeArenaPtr& node_arena, + pir::Operation* op, + const index_expr::IndexTupleExpr& indexes_expr) { + VisitInOutValue(op, [&](pir::Value value) { + const auto& ir_node = node_arena->New([&](const auto& node) { + return IndexedIrValue{node, value, indexes_expr}; + }); + const auto& ir_value = + ir_node.template Get>(); + if (!ctx->Has(value)) { + ctx->Insert(value, ir_value); + } + }); + } + + template + void VisitInOutValue(pir::Operation* op, const DoEachT& DoEach) { + for (int i = 0; i < 
op->num_operands(); ++i) { + DoEach(op->operand_source(i)); + } + for (int i = 0; i < op->num_results(); ++i) { + DoEach(op->result(i)); + } + } + + IndexedIrOp InsertOpNode( + const IndexedIrNodeArenaPtr& node_arena, pir::Operation* op) { + const auto& ir_node = node_arena->New([&](const auto& node) { + return IndexedIrOp{node, op}; + }); + return ir_node.template Get>(); + } +}; + +} // namespace detail + +inline adt::Result CreatePureElementwiseIndexedIrGraph( + const PackedIrOp& pir_op, const index_expr::IndexTupleExpr& indexes_expr) { + detail::CreatePureElementwiseIndexedIrGraphHelper helper{}; + ADT_LET_CONST_REF(ir_graph, helper.Create(pir_op, indexes_expr)); + return ir_graph; +} + +namespace detail { + +struct GetPackedIrOpInputsOutputsHelper { + adt::Result GetPackedIrOpInputsOutputs( + const PackedIrOp& pir_op, + std::vector* inputs, + std::vector* yield_op_inputs, + std::vector* outputs) { + *inputs = ap::paddle::GetUsedExternalValue(*pir_op.fusion_op); + outputs->clear(); + outputs->reserve(pir_op.fusion_op->num_results()); + for (int i = 0; i < pir_op.fusion_op->num_results(); ++i) { + outputs->emplace_back(pir_op.fusion_op->result(i)); + } + bool found_yield_op = false; + for (const auto& op : *pir_op.fusion_op.block()) { + yield_op_inputs->clear(); + yield_op_inputs->reserve(op.num_operands()); + if (op.isa()) { + for (int i = 0; i < op.num_operands(); ++i) { + yield_op_inputs->emplace_back(op.operand_source(i)); + } + found_yield_op = true; + } + } + if (found_yield_op) { + return adt::Ok{}; + } else { + return adt::errors::ValueError{ + "No yield op have been found in fusion op block."}; + } + } +}; + +} // namespace detail + +inline adt::Result GetPackedIrOpInputsOutputs( + const PackedIrOp& pir_op, + std::vector* inputs, + std::vector* yield_op_inputs, + std::vector* outputs) { + detail::GetPackedIrOpInputsOutputsHelper helper{}; + return helper.GetPackedIrOpInputsOutputs( + pir_op, inputs, yield_op_inputs, outputs); +} + +} // namespace 
ap::paddle diff --git a/paddle/ap/include/paddle/indexed_ir_node.h b/paddle/ap/include/paddle/indexed_ir_node.h new file mode 100644 index 00000000000000..5b33b2a0b52e75 --- /dev/null +++ b/paddle/ap/include/paddle/indexed_ir_node.h @@ -0,0 +1,105 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/core/value.h" + +namespace ap::paddle { + +inline void ConvertToAZaz09_(std::string* str) { + for (int i = 0; i < str->size(); ++i) { + char* ch = &str->at(i); + if (*ch >= 'a' && *ch <= 'z') { + continue; + } + if (*ch >= 'A' && *ch <= 'Z') { + continue; + } + if (*ch >= '0' && *ch <= '9') { + continue; + } + *ch = '_'; + } +} + +inline std::string GetOpUniqueName(const pir::Operation* op) { + std::string op_name = op->name(); + ConvertToAZaz09_(&op_name); + return op_name + "_" + std::to_string(op->id()); +} + +template +struct IndexedIrValueImpl { + graph::Node node; + pir::Value value; + index_expr::IndexTupleExpr indexes_expr; + + std::string GetUniqueNameInsideNodeArena() const { + if (value.defining_op()) { + return GetOpUniqueName(value.defining_op()) + "_out_" + + std::to_string(node.node_id().value()); + } else { + return std::string() + "non_op_out_" + + 
std::to_string(node.node_id().value()); + } + } + + bool operator==(const IndexedIrValueImpl& other) const { + return this->value == other.value && + this->indexes_expr == other.indexes_expr; + } +}; + +template +ADT_DEFINE_RC(IndexedIrValue, IndexedIrValueImpl); + +template +struct IndexedIrOpImpl { + graph::Node node; + pir::Operation* op; + + std::string GetUniqueNameInsideNodeArena() const { + return GetOpUniqueName(op) + +"_" + std::to_string(node.node_id().value()); + } + + bool operator==(const IndexedIrOpImpl& other) const { + return this->op == other.op; + } +}; + +template +ADT_DEFINE_RC(IndexedIrOp, IndexedIrOpImpl); + +template +using IndexedIrNodeImpl = + std::variant, IndexedIrOp>; + +struct IndexedIrNode : public IndexedIrNodeImpl { + using IndexedIrNodeImpl::IndexedIrNodeImpl; + + ADT_DEFINE_VARIANT_METHODS(IndexedIrNodeImpl); + + const graph::Node& node() const { + return Match([](const auto& impl) -> const graph::Node& { + return impl->node; + }); + } +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/meta_tensor_ptr.h b/paddle/ap/include/paddle/meta_tensor_ptr.h new file mode 100644 index 00000000000000..8f06d56da1b6b4 --- /dev/null +++ b/paddle/ap/include/paddle/meta_tensor_ptr.h @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/phi/core/meta_tensor.h" + +namespace ap::paddle { + +using MetaTensorPtr = ::phi::MetaTensor*; + +} + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using std::monostate::monostate; + + const char* Name() const { return "MetaTensorPtr"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/paddle/meta_tensor_ptr_method_class.h b/paddle/ap/include/paddle/meta_tensor_ptr_method_class.h new file mode 100644 index 00000000000000..cf650a5acfc19d --- /dev/null +++ b/paddle/ap/include/paddle/meta_tensor_ptr_method_class.h @@ -0,0 +1,149 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/ddim.h" +#include "paddle/ap/include/paddle/ddim_method_class.h" +#include "paddle/ap/include/paddle/meta_tensor_ptr.h" + +namespace ap::paddle { + +struct MetaTensorPtrMethodClass { + using This = MetaTensorPtrMethodClass; + using Self = MetaTensorPtr; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const auto* ptr = self; + ss << "<" << axpr::TypeImpl{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return reinterpret_cast(self); + } + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template CastTo()); + if (attr_name == "dtype") { + return This{}.GetDtype(self); + } + if (attr_name == "dims") { + return This{}.GetDims(self); + } + return adt::errors::AttributeError{ + std::string() + "'MetaTensorPtr' object has no attribute '" + + attr_name + "'."}; + } + + adt::Result GetDims(const Self& self) { + return GetDDimClass().New(self->dims()); + } + + adt::Result GetDtype(const Self& self) { + ADT_LET_CONST_REF(dtype, axpr::GetDataTypeFromPhiDataType(self->dtype())); + return dtype; + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 2); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template CastTo()); + if (attr_name == "dtype") { + return StaticSetDtype(self_val, args); + 
} + if (attr_name == "dims") { + return StaticSetDims(self_val, args); + } + return adt::errors::AttributeError{ + std::string() + "'MetaTensorPtr' object has no attribute '" + + attr_name + "'."}; + } + + static adt::Result StaticSetDtype( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(data_type, args.at(1).template CastTo()); + ADT_LET_CONST_REF(dtype, GetPhiDataTypeFromDataType(data_type)); + self->set_dtype(dtype); + return adt::Nothing{}; + } + + static adt::Result StaticSetDims( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + return This{}.SetDims(self, args.at(1)); + } + + adt::Result SetDims(const Self& self, + const axpr::Value& dims_val) { + return dims_val.Match( + [&](const DDim& ddims) -> adt::Result { + return SetDimsByDDim(self, ddims); + }, + [&](const adt::List& list) -> adt::Result { + return SetDimsByIntList(self, list); + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{"only DDim or list of int supported."}; + }); + } + + adt::Result SetDimsByDDim(const Self& self, const DDim& ddims) { + self->set_dims(ddims); + return adt::Nothing{}; + } + + adt::Result SetDimsByIntList( + const Self& self, const adt::List& list) { + std::vector dims{}; + dims.reserve(list->size()); + for (const auto& dim_val : *list) { + ADT_LET_CONST_REF(dim, dim_val.template CastTo()); + dims.push_back(dim); + } + self->set_dims(::common::make_ddim(dims)); + return adt::Nothing{}; + } +}; + +inline axpr::TypeImpl> +GetMetaTensorPtrClass() { + using Impl = MetaTensorPtrMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "MetaTensorPtr", [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getattr__", &Impl::GetAttr); + Define("__setattr__", &Impl::SetAttr); + })); + return 
axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/op_cuda_code_gen_impl.h b/paddle/ap/include/paddle/op_cuda_code_gen_impl.h new file mode 100644 index 00000000000000..10e535dcd20a08 --- /dev/null +++ b/paddle/ap/include/paddle/op_cuda_code_gen_impl.h @@ -0,0 +1,1018 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/axpr/pointer_type_util.h" +#include "paddle/ap/include/code_gen/code_gen_ctx.h" +#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h" +#include "paddle/ap/include/code_gen/op_code_gen_ctx.h" +#include "paddle/ap/include/code_gen/op_cuda_gen_impl.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/topo_kind.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h" +#include "paddle/ap/include/ir_match/native_or_ref_ir_value.h" +#include "paddle/ap/include/paddle/indexed_ir_graph_util.h" +#include "paddle/ap/include/paddle/pir_graph_descriptor.h" +#include "paddle/ap/include/paddle/pir_node.h" +#include "paddle/ap/include/registry/registry.h" +#include 
"paddle/ap/include/registry/registry_mgr.h" +#include "paddle/ap/include/registry/registry_singleton.h" +#include "paddle/ap/include/registry/value.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/pir/include/core/builtin_type.h" + +namespace ap::paddle { + +struct OpCudaCodeGenImpl { + using BirNode = PirNode; + using OpCodeGenCtx = code_gen::OpCodeGenCtx; + using IrOp = code_gen::IrOp; + + using DrrValue = drr::Value; + using DrrNode = drr::Node; + using DrrGraphNode = graph::Node; + using DrrPackedIrOp = drr::PackedIrOp; + using DrrOptPackedIrOp = drr::OptPackedIrOp; + using DrrOptPackedIrOpOperand = drr::OptPackedIrOpOperand; + using DrrOptPackedIrOpResult = drr::OptPackedIrOpResult; + + using DrrTrivialFusionIrOpImpl = + std::variant; + struct DrrTrivialFusionIrOp : public DrrTrivialFusionIrOpImpl { + using DrrTrivialFusionIrOpImpl::DrrTrivialFusionIrOpImpl; + ADT_DEFINE_VARIANT_METHODS(DrrTrivialFusionIrOpImpl); + + DrrGraphNode node() const { + return Match([](const auto& impl) { return impl->node; }); + } + }; + + using DrrNativeIrValue = drr::NativeIrValue; + using DrrPackedIrValue = drr::PackedIrValue; + using IndexTupleExpr = index_expr::IndexTupleExpr; + + using GraphMatchCtx = ir_match::GraphMatchCtx; + + using Registry = registry::Registry; + + using ClassAttrs = axpr::ClassAttrs; + + using Function = axpr::Function; + + adt::Result ConvertFusionOpToClassAttrs( + const OpCodeGenCtx& op_code_gen_ctx, const IrOp& ir_op) { + using RetT = adt::Result; + return ir_op.Match( + [&](const PackedIrOp& packed_ir_op) -> RetT { + return PackedIrOpConvertFusionOpToClassAttrs(op_code_gen_ctx, + packed_ir_op); + }, + [&](const RefIrOp& ref_ir_op) -> RetT { + return RefIrOpConvertFusionOpToClassAttrs(op_code_gen_ctx, ref_ir_op); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + + "only packed ir op get supported in ConvertFusionOpToLambda."}; + }); + } + + adt::Result 
PackedIrOpConvertFusionOpToClassAttrs( + const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { + ADT_LET_CONST_REF( + index_tuple_expr, + GetPureElementwiseLoopIndexTupleExpr(op_code_gen_ctx, packed_ir_op)); + ADT_LET_CONST_REF( + ir_graph, + CreatePureElementwiseIndexedIrGraph(packed_ir_op, index_tuple_expr)); + ADT_LET_CONST_REF(init_func, + PackedIrOpMakeInitFuncByFusionOp( + op_code_gen_ctx, ir_graph, packed_ir_op)); + ADT_LET_CONST_REF(compute_func, + PackedIrOpMakeComputeFuncByFusionOp( + op_code_gen_ctx, ir_graph, packed_ir_op)); + ADT_LET_CONST_REF(load_from_register_func, + PackedIrOpMakeLoadFromRegisterFuncByFusionOp( + op_code_gen_ctx, ir_graph, packed_ir_op)); + ADT_LET_CONST_REF(store_to_register_func, + PackedIrOpMakeStoreToRegisterFuncByFusionOp( + op_code_gen_ctx, ir_graph, packed_ir_op)); + std::string class_name = "PackedIrOpClass"; + adt::List>> + empty_bases{}; + axpr::AttrMap methods{}; + methods->Set("__init__", init_func); + methods->Set("compute", compute_func); + methods->Set("load_from_register", load_from_register_func); + methods->Set("store_to_register", store_to_register_func); + return ClassAttrs{class_name, empty_bases, methods}; + } + + adt::Result GetPureElementwiseLoopIndexTupleExpr( + const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { + ADT_LET_CONST_REF( + shape, GetPureElementwiseLoopDimExpr(op_code_gen_ctx, packed_ir_op)); + return index_expr::IndexTupleExprDomain{shape}; + } + + adt::Result> GetPureElementwiseLoopDimExpr( + const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { + const auto& input_flags = op_code_gen_ctx->input_index_loop_anchor_flags; + { + ADT_LET_CONST_REF( + num_native_ir_inputs, + NumNativeIrInputBirValues(op_code_gen_ctx, packed_ir_op)); + ADT_CHECK(input_flags->size() == num_native_ir_inputs) + << adt::errors::TypeError{ + std::string() + + "len(input_index_loop_anchor_flags) should equal to number of " + "native ir inputs of fusion op. 
(" + + std::to_string(input_flags->size()) + " v.s. " + + std::to_string(num_native_ir_inputs) + ")"}; + } + const auto& output_flags = op_code_gen_ctx->output_index_loop_anchor_flags; + { + ADT_LET_CONST_REF( + num_native_ir_outputs, + NumNativeIrOutputBirValues(op_code_gen_ctx, packed_ir_op)); + ADT_CHECK(output_flags->size() == num_native_ir_outputs) + << adt::errors::TypeError{ + std::string() + + "len(output_index_loop_anchor_flags) should equal to number " + "of native ir outputs of fusion op. (" + + std::to_string(output_flags->size()) + " v.s. " + + std::to_string(num_native_ir_outputs) + ")"}; + } + using Shape = adt::List; + auto GetShape = [&](pir::Value value) -> adt::Result { + ADT_LET_CONST_REF(shape_ptr, NativeIrValue{value}.GetShapeDimExprsPtr()); + Shape shape; + shape->reserve(shape_ptr->size()); + shape->assign(shape_ptr->begin(), shape_ptr->end()); + return shape; + }; + std::optional opt_shape; + auto InitOrCheckShape = [&](pir::Value value) -> adt::Result { + ADT_LET_CONST_REF(shape, GetShape(value)); + if (opt_shape.has_value()) { + ADT_CHECK(opt_shape.value() == shape) << adt::errors::TypeError{ + "All loop anchors should have same shapes."}; + } else { + opt_shape = shape; + } + return adt::Ok{}; + }; + { + int input_idx = 0; + auto DoEachNativeInput = [&](pir::Value value) -> adt::Result { + if (input_flags->at(input_idx++).value()) { + ADT_RETURN_IF_ERR(InitOrCheckShape(value)); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitNativeIrInputBirValue( + op_code_gen_ctx, packed_ir_op, DoEachNativeInput)); + } + { + int output_idx = 0; + auto DoEachNativeOutput = [&](pir::Value value) -> adt::Result { + if (output_flags->at(output_idx++).value()) { + ADT_RETURN_IF_ERR(InitOrCheckShape(value)); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitNativeIrOutputBirValue( + op_code_gen_ctx, packed_ir_op, DoEachNativeOutput)); + } + ADT_CHECK(opt_shape.has_value()) << adt::errors::TypeError{ + "At least one flag should be set in 
input_index_loop_anchor_flags or " + "output_index_loop_anchor_flags"}; + return opt_shape.value(); + } + + adt::Result NumNativeIrInputBirValues( + const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { + std::size_t num_values = 0; + auto Acc = [&](pir::Value) -> adt::Result { + ++num_values; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitNativeIrInputBirValue(op_code_gen_ctx, packed_ir_op, Acc)); + return num_values; + } + + adt::Result NumNativeIrOutputBirValues( + const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { + std::size_t num_values = 0; + auto Acc = [&](pir::Value) -> adt::Result { + ++num_values; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitNativeIrOutputBirValue(op_code_gen_ctx, packed_ir_op, Acc)); + return num_values; + } + + adt::Result RefIrOpConvertFusionOpToClassAttrs( + const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { + ADT_LET_CONST_REF( + init_func, RefIrOpMakeInitFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); + ADT_LET_CONST_REF( + compute_func, + RefIrOpMakeComputeFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); + ADT_LET_CONST_REF( + load_from_register_func, + RefIrOpMakeLoadFromRegisterFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); + ADT_LET_CONST_REF( + store_to_register_func, + RefIrOpMakeStoreToRegisterFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); + std::string class_name = "RefIrOpClass"; + adt::List>> + empty_bases{}; + axpr::AttrMap methods{}; + methods->Set("__init__", init_func); + methods->Set("compute", compute_func); + methods->Set("load_from_register", load_from_register_func); + methods->Set("store_to_register", load_from_register_func); + return ClassAttrs{class_name, empty_bases, methods}; + } + + adt::Result PackedIrOpMakeInitFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, + const IndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + return ir_graph.Match([&](const auto& impl) -> adt::Result { + return PackedIrOpMakeInitFuncByFusionOpImpl( + 
op_code_gen_ctx, impl, packed_ir_op); + }); + } + + adt::Result PackedIrOpMakeStoreToRegisterFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, + const IndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + return ir_graph.Match([&](const auto& impl) -> adt::Result { + return PackedIrOpMakeStoreToRegisterFuncByFusionOpImpl( + op_code_gen_ctx, impl, packed_ir_op); + }); + } + + adt::Result PackedIrOpMakeLoadFromRegisterFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, + const IndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + return ir_graph.Match([&](const auto& impl) -> adt::Result { + return PackedIrOpMakeLoadFromRegisterFuncByFusionOpImpl( + op_code_gen_ctx, impl, packed_ir_op); + }); + } + + adt::Result PackedIrOpMakeLoadFromRegisterFuncByFusionOpImpl( + const OpCodeGenCtx& op_code_gen_ctx, + const PureElementwiseIndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + axpr::LambdaExprBuilder lmbd; + auto GetMapFunc = [&](auto& ctx) -> axpr::AnfExpr { + auto& value_class_var = + ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); + auto& name_var = ctx.Var("indexed_ir_node_info_tuple").At(0); + auto& index_tuple_expr_var = ctx.Var("indexed_ir_node_info_tuple").At(1); + auto& dtype_var = ctx.Var("indexed_ir_node_info_tuple").At(2); + auto& input_var = value_class_var.Call( + index_tuple_expr_var, dtype_var, ctx.Var("input_local_var_name")); + return ctx.Var(axpr::kBuiltinList()).Call(name_var, input_var); + }; + using AnfExprs = std::vector; + auto GetAllInputIndexedIrNodeInfo = + [&](auto* ctx) -> adt::Result { + AnfExprs ret; + auto DoEachNativeIrValue = + [&](pir::Value ir_value) -> adt::Result { + AnfExprs indexed_ir_info_tuple; + ADT_LET_CONST_REF(dtype, ConvertToDataType(ir_value)); + for (const auto& input : ir_graph->inputs) { + if (input->value == ir_value) { + auto& info_var = + ctx->Var(axpr::kBuiltinList()) + .Call(ctx->String(input->GetUniqueNameInsideNodeArena()), + 
ctx->Var("self").Attr("loop_index_tuple_expr"), + ctx->Var("DataType").Attr(dtype.Name())); + indexed_ir_info_tuple.emplace_back( + static_cast(info_var)); + } + } + auto& indexed_ir_info_var = + ctx->Var(axpr::kBuiltinList()).Apply(indexed_ir_info_tuple); + ret.emplace_back(static_cast(indexed_ir_info_var)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitNativeIrInputBirValue( + op_code_gen_ctx, packed_ir_op, DoEachNativeIrValue)); + return ret; + }; + auto GetBody = [&](auto& ctx) -> adt::Result { + const auto& map_func_var_name = ctx.NewTmpVarName(); + ctx.Var(map_func_var_name) = + lmbd.Lambda({"indexed_ir_node_info_tuple"}, GetMapFunc); + ADT_LET_CONST_REF(indexed_nodes, GetAllInputIndexedIrNodeInfo(&ctx)); + auto& indexed_nodes_var = + ctx.Var(axpr::kBuiltinList()).Apply(indexed_nodes); + auto& native_input_indexed_nodes_var = + indexed_nodes_var.At(ctx.Var("native_input_index")); + auto& items_var = ctx.Var("map").Call(ctx.Var(map_func_var_name), + native_input_indexed_nodes_var); + auto& ret = ctx.Var("OrderedDict").Call(items_var); + return static_cast>(ret); + }; + ADT_LET_CONST_REF(anf_expr, + lmbd.TryLambda({"self", + "code_gen_ctx", + "input_local_var_name", + "native_input_index"}, + GetBody)); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + adt::Result PackedIrOpMakeStoreToRegisterFuncByFusionOpImpl( + const OpCodeGenCtx& op_code_gen_ctx, + const PureElementwiseIndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + ADT_CHECK(ir_graph->yield_op_inputs.size() == ir_graph->outputs.size()); + auto GetOutputIndex = + [&](pir::Value output) -> adt::Result> { + for (int i = 0; i < ir_graph->outputs.size(); ++i) { + if (output == ir_graph->outputs.at(i)) { + return i; + } + } + return std::nullopt; + }; + axpr::LambdaExprBuilder lmbd; + using AnfExprs = 
std::vector; + auto GetAllOutputIndexedIrNodeInfo = + [&](auto* ctx) -> adt::Result { + AnfExprs ret; + auto DoEachNativeIrValue = + [&](pir::Value ir_value) -> adt::Result { + ADT_LET_CONST_REF(dtype, ConvertToDataType(ir_value)); + ADT_LET_CONST_REF(opt_idx, GetOutputIndex(ir_value)); + ADT_CHECK(opt_idx.has_value()); + const auto& output = ir_graph->yield_op_inputs.at(opt_idx.value()); + auto& indexed_ir_info_tuple = + ctx->Var(axpr::kBuiltinList()) + .Call(ctx->String(output->GetUniqueNameInsideNodeArena()), + ctx->Var("self").Attr("loop_index_tuple_expr"), + ctx->Var("DataType").Attr(dtype.Name())); + ret.emplace_back(indexed_ir_info_tuple); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitNativeIrOutputBirValue( + op_code_gen_ctx, packed_ir_op, DoEachNativeIrValue)); + return ret; + }; + + auto GetBody = [&](auto& ctx) -> adt::Result { + ADT_LET_CONST_REF(indexed_nodes, GetAllOutputIndexedIrNodeInfo(&ctx)); + auto& indexed_nodes_var = + ctx.Var(axpr::kBuiltinList()).Apply(indexed_nodes); + auto& native_output_indexed_node_var = + indexed_nodes_var.At(ctx.Var("native_output_index")); + auto& name_var = native_output_indexed_node_var.At(0); + auto& output_var = ctx.Var("compute_results").At(name_var); + auto& value_class_var = + ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); + auto& index_tuple_expr_var = native_output_indexed_node_var.At(1); + auto& dtype_var = native_output_indexed_node_var.At(2); + auto& store_var = value_class_var.Call( + index_tuple_expr_var, dtype_var, ctx.Var("out_value_local_var_name")); + ctx.Var("code_gen_ctx").Attr("assign").Call(store_var, output_var); + return ctx.None(); + }; + ADT_LET_CONST_REF(anf_expr, + lmbd.TryLambda({"self", + "code_gen_ctx", + "compute_results", + "out_value_local_var_name", + "native_output_index"}, + GetBody)); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + 
atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + adt::Result PackedIrOpMakeComputeFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, + const IndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + return ir_graph.Match([&](const auto& impl) -> adt::Result { + return PackedIrOpMakeComputeFuncByFusionOpImpl( + op_code_gen_ctx, impl, packed_ir_op); + }); + } + + adt::Result PackedIrOpMakeComputeFuncByFusionOpImpl( + const OpCodeGenCtx& op_code_gen_ctx, + const PureElementwiseIndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + axpr::LambdaExprBuilder lmbd; + using Ok = adt::Result; + auto UnpackInputs = [&](auto* ctx) -> Ok { + for (const auto& input : ir_graph->inputs) { + const auto& name = input->GetUniqueNameInsideNodeArena(); + ctx->Var(name) = ctx->Var("inputs").At(ctx->String(name)); + } + return adt::Ok{}; + }; + auto ComputeNativeOpCodeGen = [&](auto* ctx, + const auto& indexed_ir_op) -> Ok { + ADT_LET_CONST_REF(input_var_names, GetInputVarNames(indexed_ir_op)); + const auto& indexed_ir_op_name = + indexed_ir_op->GetUniqueNameInsideNodeArena(); + ADT_LET_CONST_REF(output_var_names, GetOutputVarNames(indexed_ir_op)); + std::vector args{ctx->Var("code_gen_ctx")}; + args.reserve(input_var_names.size() + 1); + for (const auto& input_var_name : input_var_names) { + args.push_back(ctx->Var(input_var_name)); + } + auto& outputs_var = ctx->Var("self").Attr(indexed_ir_op_name).Apply(args); + for (int i = 0; i < output_var_names.size(); ++i) { + const auto& output_var_name = output_var_names.at(i); + ctx->Var(output_var_name) = outputs_var.At(i); + } + return adt::Ok{}; + }; + auto PackedOutputs = [&](auto* ctx) -> adt::Result { + std::vector yield_op_input_items; + yield_op_input_items.reserve(ir_graph->yield_op_inputs.size()); + for (const auto& yield_op_input : ir_graph->yield_op_inputs) { + const auto& name = yield_op_input->GetUniqueNameInsideNodeArena(); + const auto& pair = ctx->Var(axpr::kBuiltinList()) + 
.Call(ctx->String(name), ctx->Var(name)); + yield_op_input_items.emplace_back(static_cast(pair)); + } + const auto& items = + ctx->Var(axpr::kBuiltinList()).Call(yield_op_input_items); + return ctx->Call("OrderedDict", items); + }; + auto GetBody = [&](auto& ctx) -> adt::Result { + auto* ctx_ptr = &ctx; + ADT_RETURN_IF_ERR(UnpackInputs(ctx_ptr)); + ADT_RETURN_IF_ERR( + VisitIndexedIrOp(ir_graph, [&](const auto& indexed_ir_op) -> Ok { + return ComputeNativeOpCodeGen(ctx_ptr, indexed_ir_op); + })); + ADT_LET_CONST_REF(packed_outputs, PackedOutputs(ctx_ptr)); + return packed_outputs; + }; + std::vector arg_names{"self", "code_gen_ctx", "inputs"}; + ADT_LET_CONST_REF(anf_expr, lmbd.TryLambda(arg_names, GetBody)); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + adt::Result> GetInputVarNames( + const IndexedIrOp& indexed_ir_op) const { + ADT_LET_CONST_REF(upstreams, indexed_ir_op->node.UpstreamNodes()); + std::vector ret{}; + ret.reserve(upstreams.size()); + auto DoEach = [&](const auto& node) -> adt::Result { + ADT_LET_CONST_REF(ir_node, node.Get()); + ADT_LET_CONST_REF( + ir_value, ir_node.template TryGet>()); + ret.push_back(ir_value->GetUniqueNameInsideNodeArena()); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(upstreams.VisitNodes(DoEach)); + return ret; + } + + adt::Result> GetOutputVarNames( + const IndexedIrOp& indexed_ir_op) { + ADT_LET_CONST_REF(downstreams, indexed_ir_op->node.DownstreamNodes()); + std::vector ret{}; + ret.reserve(downstreams.size()); + auto DoEach = [&](const auto& node) -> adt::Result { + ADT_LET_CONST_REF(ir_node, node.Get()); + ADT_LET_CONST_REF( + ir_value, ir_node.template TryGet>()); + ret.push_back(ir_value->GetUniqueNameInsideNodeArena()); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(downstreams.VisitNodes(DoEach)); + return ret; + } + + adt::Result 
PackedIrOpMakeInitFuncByFusionOpImpl( + const OpCodeGenCtx& op_code_gen_ctx, + const PureElementwiseIndexedIrGraph& ir_graph, + const PackedIrOp& packed_ir_op) { + axpr::LambdaExprBuilder lmbd; + using Ok = adt::Result; + auto ConstructNativeOpCodeGen = [&](auto* ctx, + const auto& indexed_ir_op) -> Ok { + const auto& op_name = indexed_ir_op->op->name(); + auto& class_var = ctx->Var("get_native_op_code_generator_class") + .Call(ctx->String(op_name)); + { + std::vector input_dtype_anf_exprs; + for (int i = 0; i < indexed_ir_op->op->num_operands(); ++i) { + ADT_LET_CONST_REF( + dtype, ConvertToDataType(indexed_ir_op->op->operand_source(i))); + const auto& dtype_var = ctx->Var("DataType").Attr(dtype.Name()); + input_dtype_anf_exprs.emplace_back( + static_cast(dtype_var)); + } + ctx->Var("input_dtypes") = + ctx->Call(axpr::kBuiltinList(), input_dtype_anf_exprs); + } + { + std::vector output_dtype_anf_exprs; + for (int i = 0; i < indexed_ir_op->op->num_results(); ++i) { + ADT_LET_CONST_REF(dtype, + ConvertToDataType(indexed_ir_op->op->result(i))); + const auto& dtype_var = ctx->Var("DataType").Attr(dtype.Name()); + output_dtype_anf_exprs.emplace_back( + static_cast(dtype_var)); + } + ctx->Var("output_dtypes") = + ctx->Call(axpr::kBuiltinList(), output_dtype_anf_exprs); + } + { + std::vector input_index_tuple_exprs; + input_index_tuple_exprs.reserve(indexed_ir_op->op->num_operands()); + for (int i = 0; i < indexed_ir_op->op->num_operands(); ++i) { + input_index_tuple_exprs.emplace_back( + ctx->Var("loop_index_tuple_expr")); + } + ctx->Var("input_index_tuple_exprs") = + ctx->Call(axpr::kBuiltinList(), input_index_tuple_exprs); + } + { + std::vector output_index_tuple_exprs; + output_index_tuple_exprs.reserve(indexed_ir_op->op->num_results()); + for (int i = 0; i < indexed_ir_op->op->num_results(); ++i) { + output_index_tuple_exprs.emplace_back( + ctx->Var("loop_index_tuple_expr")); + } + ctx->Var("output_index_tuple_exprs") = + ctx->Call(axpr::kBuiltinList(), 
output_index_tuple_exprs); + } + const auto& indexed_op_name = + indexed_ir_op->GetUniqueNameInsideNodeArena(); + axpr::AnfExpr indexed_op = + class_var.Call(ctx->Var("index_expr_code_gen"), + ctx->String(indexed_op_name), + ctx->Var("input_dtypes"), + ctx->Var("output_dtypes"), + ctx->Var("input_index_tuple_exprs"), + ctx->Var("output_index_tuple_exprs"), + /*attrs*/ ctx->None()); + ctx->Var("self").SetAttr(indexed_op_name, indexed_op); + return adt::Ok{}; + }; + auto GetBody = [&](auto& ctx) -> adt::Result { + ctx.Var("self").SetAttr("class_factory", ctx.Var("class_factory")); + ctx.Var("self").SetAttr("loop_index_tuple_expr", + ctx.Var("loop_index_tuple_expr")); + ctx.Var("index_expr_code_generator_class") = + ctx.Var("class_factory") + .Attr("get_index_expr_code_generator_class") + .Call(); + ctx.Var("index_expr_code_gen") = + ctx.Var("index_expr_code_generator_class") + .Call(ctx.Var("loop_var_names")); + ctx.Var("get_native_op_code_generator_class") = + ctx.Var("class_factory") + .Attr("get_native_op_code_generator_class") + .Call(); + auto* ctx_ptr = &ctx; + ADT_RETURN_IF_ERR( + VisitIndexedIrOp(ir_graph, [&](const auto& indexed_ir_op) -> Ok { + return ConstructNativeOpCodeGen(ctx_ptr, indexed_ir_op); + })); + return ctx.None(); + }; + ADT_LET_CONST_REF(anf_expr, + lmbd.TryLambda({"self", + "class_factory", + "loop_index_tuple_expr", + "loop_var_names"}, + GetBody)); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + template + adt::Result VisitIndexedIrOp( + const PureElementwiseIndexedIrGraph& ir_graph, + const DoEachIndexIrNodeT& DoEachIndexIrNode) { + for (const auto& node : ir_graph->node_arena->nodes()) { + if (node.template Has>()) { + ADT_RETURN_IF_ERR( + DoEachIndexIrNode(node.template Get>())); + } + } + return adt::Ok{}; + } + + adt::Result 
RefIrOpMakeInitFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { + axpr::LambdaExprBuilder lmbd; + auto GetBody = [](auto& ctx) { + ctx.Var("self").SetAttr("class_factory", ctx.Var("class_factory")); + ctx.Var("self").SetAttr("loop_index_tuple_expr", + ctx.Var("loop_index_tuple_expr")); + return ctx.None(); + }; + const auto& anf_expr = lmbd.Lambda( + {"self", "class_factory", "loop_index_tuple_expr", "loop_var_names"}, + GetBody); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + adt::Result RefIrOpMakeComputeFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { + axpr::LambdaExprBuilder lmbd; + auto GetBody = [](auto& ctx) -> axpr::AnfExpr { return ctx.Var("inputs"); }; + const auto& anf_expr = + lmbd.Lambda({"self", "code_gen_ctx", "inputs"}, GetBody); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + adt::Result RefIrOpMakeLoadFromRegisterFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { + pir::Value value = ref_ir_op.ref_node_info->ir_value.value; + ADT_LET_CONST_REF(dtype, ConvertToDataType(value)); + axpr::LambdaExprBuilder lmbd; + auto GetBody = [&](auto& ctx) { + auto& value_class_var = + ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); + auto& index_tuple_expr_var = + ctx.Var("self").Attr("loop_index_tuple_expr"); + auto& dtype_var = ctx.Var("DataType").Attr(dtype.Name()); + auto& input_var = value_class_var.Call( + index_tuple_expr_var, dtype_var, ctx.Var("input_local_var_name")); + return ctx.Var("OrderedDict") + .Call(ctx.Var(axpr::kBuiltinList()) + 
.Call(ctx.Var(axpr::kBuiltinList()) + .Call(ctx.String("sole_ir_value"), input_var))); + }; + const auto& anf_expr = lmbd.Lambda( + {"self", "code_gen_ctx", "input_local_var_name", "native_input_index"}, + GetBody); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + adt::Result RefIrOpMakeStoreToRegisterFuncByFusionOp( + const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { + pir::Value value = ref_ir_op.ref_node_info->ir_value.value; + ADT_LET_CONST_REF(dtype, ConvertToDataType(value)); + axpr::LambdaExprBuilder lmbd; + auto GetBody = [&](auto& ctx) { + auto& value_class_var = + ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); + auto& index_tuple_expr_var = + ctx.Var("self").Attr("loop_index_tuple_expr"); + auto& dtype_var = ctx.Var("DataType").Attr(dtype.Name()); + auto& output_var = value_class_var.Call( + index_tuple_expr_var, dtype_var, ctx.Var("out_value_local_var_name")); + ctx.Var("code_gen_ctx") + .Attr("assign") + .Call(output_var, + ctx.Var("compute_results").At(ctx.String("sole_ir_value"))); + return ctx.None(); + }; + const auto& anf_expr = lmbd.Lambda({"self", + "code_gen_ctx", + "compute_results", + "out_value_local_var_name", + "native_output_index"}, + GetBody); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + return Function{lambda, std::nullopt}; + } + + using NativeOrRefIrValue = ir_match::NativeOrRefIrValue; + + template + adt::Result VisitNativeIrInputBirValue( + const OpCodeGenCtx& op_code_gen_ctx, + const PackedIrOp& packed_ir_op, + const DoEachT& DoEach) { + ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); + ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, + 
GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); + auto DoEachNativeValue = + [&](const auto& drr_ir_value) -> adt::Result { + ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); + return DoEach(value); + }; + auto DoEachPackedValue = + [&](const auto& drr_ir_value) -> adt::Result { + // Do nothing. + return adt::Ok{}; + }; + return VisitDrrTrivialFusionIrOpInput( + drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); + } + + template + adt::Result VisitInputBirNativeIrValue( + const OpCodeGenCtx& op_code_gen_ctx, + const PackedIrOp& packed_ir_op, + const DoEachT& DoEach) { + ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); + ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, + GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); + auto DoEachNativeValue = + [&](const auto& drr_ir_value) -> adt::Result { + ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); + return DoEach(value); + }; + auto DoEachPackedValue = + [&](const auto& drr_ir_value) -> adt::Result { + ADT_RETURN_IF_ERR( + VisitPackedPirValue(graph_match_ctx, drr_ir_value, DoEach)); + return adt::Ok{}; + }; + return VisitDrrTrivialFusionIrOpInput( + drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); + } + + template + adt::Result VisitPackedPirValue(const GraphMatchCtx& match_ctx, + const DrrPackedIrValue& drr_ir_value, + const DoEachT& DoEach) { + auto DoEachPirNode = [&](const PirNode& pir_node) -> adt::Result { + ADT_LET_CONST_REF(pir_value, pir_node.template TryGet()); + ADT_RETURN_IF_ERR(DoEach(pir_value.value)); + return adt::Ok{}; + }; + const auto& node = drr_ir_value->node; + ADT_RETURN_IF_ERR( + match_ctx->VisitPackedBigGraphIrValueNode(node, DoEachPirNode)); + return adt::Ok{}; + } + + template + adt::Result VisitNativeIrOutputBirValue( + const OpCodeGenCtx& op_code_gen_ctx, + const PackedIrOp& packed_ir_op, + const DoEachT& DoEach) { + ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); 
+ ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, + GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); + auto DoEachNativeValue = + [&](const auto& drr_ir_value) -> adt::Result { + ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); + return DoEach(value); + }; + auto DoEachPackedValue = + [&](const auto& drr_ir_value) -> adt::Result { + // Do nothing. + return adt::Ok{}; + }; + return VisitDrrTrivialFusionIrOpOutput( + drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); + } + + template + adt::Result VisitOutputNativeIrValue( + const OpCodeGenCtx& op_code_gen_ctx, + const PackedIrOp& packed_ir_op, + const DoEachT& DoEach) { + ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); + ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, + GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); + auto DoEachNativeValue = + [&](const auto& drr_ir_value) -> adt::Result { + ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); + return DoEach(value); + }; + auto DoEachPackedValue = + [&](const auto& drr_ir_value) -> adt::Result { + ADT_RETURN_IF_ERR( + VisitPackedPirValue(graph_match_ctx, drr_ir_value, DoEach)); + return adt::Ok{}; + }; + return VisitDrrTrivialFusionIrOpOutput( + drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); + } + + template + adt::Result VisitDrrTrivialFusionIrOpInput( + const DrrTrivialFusionIrOp& drr_trivial_fusion_ir_op, + const DoEachNativeValueT& DoEachNativeValue, + const DoEachPackedValueT DoEachPackedValue) { + LOG(ERROR) << "drr_trivial_fusion_ir_op: " + << graph::NodeDescriptor{}.DebugId( + drr_trivial_fusion_ir_op.node()); + auto DoEach = [&](const DrrGraphNode& node) -> adt::Result { + ADT_LET_CONST_REF(drr_node, node.Get()); + LOG(ERROR) << "drr_trivial_fusion_ir_op input: " + << graph::NodeDescriptor{}.DebugId(node); + return drr_node.Match( + [&](const DrrNativeIrValue& ir_value) -> adt::Result { + return DoEachNativeValue(ir_value); + }, + [&](const 
DrrPackedIrValue& ir_value) -> adt::Result { + return DoEachPackedValue(ir_value); + }, + [&](const auto&) -> adt::Result { + return adt::errors::ValueError{ + "the second connected upstreams of drr packed ir op should be " + "drr native ir values or drr packed ir values."}; + }); + }; + return VisitSecondConnectedUpstream(drr_trivial_fusion_ir_op.node(), + DoEach); + } + + template + adt::Result VisitDrrTrivialFusionIrOpOutput( + const DrrTrivialFusionIrOp& drr_trivial_fusion_ir_op, + const DoEachNativeValueT& DoEachNativeValue, + const DoEachPackedValueT DoEachPackedValue) { + auto DoEach = [&](const DrrGraphNode& node) -> adt::Result { + ADT_LET_CONST_REF(drr_node, node.Get()); + return drr_node.Match( + [&](const DrrNativeIrValue& ir_value) -> adt::Result { + return DoEachNativeValue(ir_value); + }, + [&](const DrrPackedIrValue& ir_value) -> adt::Result { + return DoEachPackedValue(ir_value); + }, + [&](const auto&) -> adt::Result { + return adt::errors::ValueError{ + "the second connected upstreams of drr packed ir op should be " + "drr native ir values or drr packed ir values."}; + }); + }; + return VisitSecondConnectedDownstream(drr_trivial_fusion_ir_op.node(), + DoEach); + } + + template + adt::Result VisitSecondConnectedUpstream(const DrrGraphNode& node, + const DoEachT& DoEach) { + auto DoEachUpstream = [&](const auto& upstream) -> adt::Result { + return VisitUpstream(upstream, DoEach); + }; + return VisitUpstream(node, DoEachUpstream); + } + + template + adt::Result VisitSecondConnectedDownstream(const DrrGraphNode& node, + const DoEachT& DoEach) { + auto DoEachUpstream = [&](const auto& downstream) -> adt::Result { + return VisitDownstream(downstream, DoEach); + }; + return VisitDownstream(node, DoEachUpstream); + } + + template + adt::Result VisitUpstream(const DrrGraphNode& node, + const DoEachT& DoEach) { + ADT_LET_CONST_REF(upstreams, node.UpstreamNodes()); + return upstreams.VisitNodes(DoEach); + } + + template + adt::Result VisitDownstream(const 
DrrGraphNode& node, + const DoEachT& DoEach) { + ADT_LET_CONST_REF(downstreams, node.DownstreamNodes()); + return downstreams.VisitNodes(DoEach); + } + + adt::Result GetPirValue( + const GraphMatchCtx& graph_match_ctx, + const DrrNativeIrValue& drr_native_ir_value) { + const auto& node = drr_native_ir_value->node; + ADT_LET_CONST_REF(pir_node, graph_match_ctx->GetSoleBigGraphNode(node)); + ADT_LET_CONST_REF(pir_value, pir_node.template TryGet()); + return pir_value.value; + } + + adt::Result GetDrrTrivialFusionIrOp( + const GraphMatchCtx& graph_match_ctx, const PackedIrOp& packed_ir_op) { + ADT_LET_CONST_REF(node, + graph_match_ctx->GetMatchedSmallGraphNode(packed_ir_op)); + ADT_LET_CONST_REF(drr_node, node.Get()); + using RetT = adt::Result; + return drr_node.Match( + [&](const DrrPackedIrOp& impl) -> RetT { return impl; }, + [&](const DrrOptPackedIrOp& impl) -> RetT { return impl; }, + [&](const auto&) -> RetT { + return adt::errors::NotImplementedError{ + "conversion from DrrNode to DrrTrivialFusionIrOp failed."}; + }); + } + + adt::Result GetGraphMatchCtx( + const OpCodeGenCtx& op_code_gen_ctx) const { + ADT_LET_CONST_REF(code_gen_ctx, + adt::WeakPtrLock(op_code_gen_ctx->code_gen_ctx)); + ADT_CHECK(code_gen_ctx->ir_match_ctx.has_value()); + const auto& ir_match_ctx = code_gen_ctx->ir_match_ctx.value(); + return ir_match_ctx->graph_match_ctx; + } + + adt::Result GetConstDataPointerType(pir::Value value) { + ADT_LET_CONST_REF(data_type, ConvertToDataType(value)); + return axpr::GetConstPointerTypeFromDataType(data_type); + } + + adt::Result GetMutableDataPointerType(pir::Value value) { + ADT_LET_CONST_REF(data_type, ConvertToDataType(value)); + return axpr::GetMutablePointerTypeFromDataType(data_type); + } + + adt::Result ConvertToDataType(pir::Value value) { + ADT_LET_CONST_REF(dtype, ConvertToPhiDataType(value)); + return ap::axpr::GetDataTypeFromPhiDataType(dtype); + } + + adt::Result ConvertToPhiDataType(pir::Value value) { + ADT_LET_CONST_REF(type, 
GetPirDataType(value)); + try { + return ::paddle::dialect::TransToPhiDataType(type); + } catch (const std::exception& e) { + return adt::errors::TypeError{ + "failed to cast from pir data type to phi data type."}; + } + } + + adt::Result GetPirDataType(pir::Value value) { + if (!value.type().isa()) { + return adt::errors::NotImplementedError{ + "pir value must be of DenseTensorType"}; + } + const auto dense_tensor_type = + value.type().dyn_cast(); + return dense_tensor_type.dtype(); + } +}; + +} // namespace ap::paddle + +namespace ap::code_gen { + +template <> +struct OpCudaCodeGenImpl + : public paddle::OpCudaCodeGenImpl {}; + +} // namespace ap::code_gen diff --git a/paddle/ap/include/paddle/pass/ap_drr_helper.h b/paddle/ap/include/paddle/pass/ap_drr_helper.h new file mode 100644 index 00000000000000..a251cff82a66b7 --- /dev/null +++ b/paddle/ap/include/paddle/pass/ap_drr_helper.h @@ -0,0 +1,56 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/drr/drr_interpreter.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/registry/abstract_drr_pass_registry_item.h" + +namespace cinn::dialect::ir { + +struct ApDrrHelper { + public: + explicit ApDrrHelper(const std::weak_ptr& + circlable_ref_list); + using Function = ap::axpr::Value; + + using DrrNode = ap::drr::Node; + using DrrCtx = ap::drr::DrrCtx; + + ap::adt::Result Interpret( + const Function& function, const std::vector& args) { + return drr_interpreter_.Interpret(function, args); + } + + ap::adt::Result InterpretDrrCtxMaker( + const Function& lambda, const std::vector& args); + + ap::adt::Result Interpret(const Function& lambda, + const std::string& abstract_drr_pass_name); + + ap::adt::Result Interpret( + const ap::axpr::ClassAttrs& cls); + + ap::adt::Result CreateDrrCtxByDrrPassObj( + const ap::axpr::Value& drr_pass_obj); + + ap::drr::DrrInterpreter* mut_drr_interpreter() { return &drr_interpreter_; } + + private: + mutable ap::drr::DrrInterpreter drr_interpreter_; +}; + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h b/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h new file mode 100644 index 00000000000000..69f4fa40ad6f8c --- /dev/null +++ b/paddle/ap/include/paddle/pass/ap_kernel_define_helper.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_gen/code_gen_ctx.h" +#include "paddle/ap/include/code_gen/code_gen_result.h" +#include "paddle/ap/include/code_gen/value.h" +#include "paddle/ap/include/code_module/code_module.h" +#include "paddle/ap/include/paddle/pir_node.h" + +namespace cinn::dialect::ir { + +struct ApKernelDefineHelper { + std::weak_ptr circlable_ref_list_; + + explicit ApKernelDefineHelper( + const std::weak_ptr& circlable_ref_list) + : circlable_ref_list_(circlable_ref_list) {} + + using Function = ap::axpr::Value; + using CodeModule = ap::code_module::CodeModule; + using PirNode = ap::paddle::PirNode; + using CGValue = ap::code_gen::Value; + using CodeGenCtx = ap::code_gen::CodeGenCtx; + using CodeGenResult = ap::code_gen::CodeGenResult; + + ap::adt::Result Interpret(const Function& lambda, + const CodeGenCtx& code_gen_ctx); +}; + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h b/paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h new file mode 100644 index 00000000000000..1254056ca335c7 --- /dev/null +++ b/paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h @@ -0,0 +1,56 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/pir/include/pass/pass.h" + +namespace ap::memory { + +class CirclableRefListBase; + +} + +namespace ap::axpr { + +struct Value; + +} + +namespace cinn { +namespace dialect { +namespace ir { + +std::optional> +CreateApLowerFusionOpAbstractDrrPass( + const std::weak_ptr& circlable_ref_list); +std::optional> CreateApLowerFusionOpClassicDrrPass( + const std::weak_ptr& circlable_ref_list); + +std::optional> CreateAccessTopoDrrPass( + const std::weak_ptr& circlable_ref_list, + const std::string& drr_pass_tag, + std::optional steps_limit); + +std::optional> CreateCustomAccessTopoDrrPass( + const std::weak_ptr& circlable_ref_list, + const ap::axpr::Value& drr_pass, + std::optional steps_limit, + const ap::axpr::Value& mut_matched_pattern_as_programs); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/ap/include/paddle/pass/ap_registry_helper.h b/paddle/ap/include/paddle/pass/ap_registry_helper.h new file mode 100644 index 00000000000000..b18b3d628e566e --- /dev/null +++ b/paddle/ap/include/paddle/pass/ap_registry_helper.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/registry/registry.h" + +namespace cinn::dialect::ir { + +struct ApRegistryHelper { + ap::adt::Result SingletonRegistry(); +}; + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/include/paddle/pass/ir_helper.h b/paddle/ap/include/paddle/pass/ir_helper.h new file mode 100644 index 00000000000000..aa4d372805b356 --- /dev/null +++ b/paddle/ap/include/paddle/pass/ir_helper.h @@ -0,0 +1,19 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" +#include "paddle/ap/include/paddle/pir_node.h" diff --git a/paddle/ap/include/paddle/pass/ir_helper_method_class.h b/paddle/ap/include/paddle/pass/ir_helper_method_class.h new file mode 100644 index 00000000000000..3d63ef993944c8 --- /dev/null +++ b/paddle/ap/include/paddle/pass/ir_helper_method_class.h @@ -0,0 +1,42 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/paddle/pass/ap_drr_helper.h" +#include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" +#include "paddle/ap/include/paddle/pass/ir_helper.h" +#include "paddle/ap/include/paddle/pir/op_dialect.h" +#include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h" +#include "paddle/ap/include/paddle/pir/pass_manager_method_class.h" +#include "paddle/ap/include/paddle/pir/pass_method_class.h" +#include "paddle/ap/include/paddle/pir/program_method_class.h" +#include "paddle/ap/include/paddle/pir_node_helper.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/utils/general_functions.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace ap::paddle { + +void ForceLinkIrTools(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h b/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h new file mode 100644 index 00000000000000..e39ae72c322863 --- /dev/null +++ b/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/phi/core/meta_tensor.h" + +namespace phi { + +namespace adt = ap::adt; + +struct ApInferMetaHelper { + using CoreExpr = ap::axpr::CoreExpr; + using Lambda = ap::axpr::Lambda; + + adt::Result InferMeta(const std::string& lambda, + const std::vector* inputs, + std::vector* outputs); +}; + +} // namespace phi diff --git a/paddle/ap/include/paddle/phi/device_ctx.h b/paddle/ap/include/paddle/phi/device_ctx.h new file mode 100644 index 00000000000000..65365745f20e98 --- /dev/null +++ b/paddle/ap/include/paddle/phi/device_ctx.h @@ -0,0 +1,41 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/kernel_dispatch/device_ctx.h" + +namespace ap::paddle { + +template +class DeviceCtx : public kernel_dispatch::DeviceCtxImpl { + private: + const PhiDeviceCtx* phi_device_ctx_; + using StreamT = decltype(std::declval().stream()); + std::optional stream_; + + public: + explicit DeviceCtx(const PhiDeviceCtx* phi_device_ctx) + : phi_device_ctx_(phi_device_ctx) {} + + adt::Result GetStreamAddrAsVoidPtr() override { + if (!stream_.has_value()) { + stream_ = phi_device_ctx_->stream(); + } + void* stream_ptr = &stream_.value(); + return axpr::PointerValue{stream_ptr}; + } +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/phi/kernel_define_helper.h b/paddle/ap/include/paddle/phi/kernel_define_helper.h new file mode 100644 index 00000000000000..cc6a9f55f83e06 --- /dev/null +++ b/paddle/ap/include/paddle/phi/kernel_define_helper.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_module/code_module.h" + +namespace phi { + +namespace adt = ap::adt; + +struct KernelDefineHelper { + using CoreExpr = ap::axpr::CoreExpr; + using Lambda = ap::axpr::Lambda; + using CodeModule = ap::code_module::CodeModule; + + adt::Result InterpretKernelDefineLambda( + const Lambda& code_module_lambda); +}; + +} // namespace phi diff --git a/paddle/ap/include/paddle/phi/kernel_dispatch_helper.h b/paddle/ap/include/paddle/phi/kernel_dispatch_helper.h new file mode 100644 index 00000000000000..7d40216d5f4147 --- /dev/null +++ b/paddle/ap/include/paddle/phi/kernel_dispatch_helper.h @@ -0,0 +1,42 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/kernel_dispatch/value.h" +#include "paddle/ap/include/memory/circlable_ref_list_base.h" + +namespace phi { + +namespace adt = ap::adt; + +class KernelDispatchHelper { + std::shared_ptr circlable_ref_list_; + + public: + KernelDispatchHelper(); + + using CoreExpr = ap::axpr::CoreExpr; + using Lambda = ap::axpr::Lambda; + using Val = ap::kernel_dispatch::Val; + using DispatchCtx = ap::kernel_dispatch::DispatchCtx; + + adt::Result InterpretCtxMaker(const Lambda& ctx_maker_lambda); + + adt::Result InterpretKernelDispatcher( + const Lambda& kernel_dispatch_lambda, const DispatchCtx& dispatch_ctx); +}; + +} // namespace phi diff --git a/paddle/ap/include/paddle/phi/place.h b/paddle/ap/include/paddle/phi/place.h new file mode 100644 index 00000000000000..1a054b9b07bd77 --- /dev/null +++ b/paddle/ap/include/paddle/phi/place.h @@ -0,0 +1,20 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/phi/common/place.h" + +namespace ap::paddle {} diff --git a/paddle/ap/include/paddle/phi/place_method_class.h b/paddle/ap/include/paddle/phi/place_method_class.h new file mode 100644 index 00000000000000..0d0e505e4121e0 --- /dev/null +++ b/paddle/ap/include/paddle/phi/place_method_class.h @@ -0,0 +1,103 @@ +// Copyright (c) 2024 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/phi/place.h" + +namespace ap::paddle { + +inline adt::Result PlaceToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0); + const auto& str = self.DebugString(); + return str; +} + +inline axpr::TypeImpl> GetPlaceClass() { + static auto cls(axpr::MakeBuiltinClass( + "Place", [&](const auto& DoEach) { DoEach("__str__", &PlaceToString); })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +inline adt::Result CreateUndefinedPlace( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + phi::Place place; + return GetPlaceClass().New(place); +} + +inline adt::Result CreateCPUPlace( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + phi::Place place = phi::CPUPlace(); + return GetPlaceClass().New(place); +} + +inline adt::Result CreateGPUPlace( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(device_id, args.at(0).template TryGet()); + phi::Place place = phi::GPUPlace(device_id); + return GetPlaceClass().New(place); +} + 
+inline adt::Result CreateGPUPinnedPlace( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + phi::Place place = phi::GPUPinnedPlace(); + return GetPlaceClass().New(place); +} + +inline adt::Result CreateXPUPlace( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(device_id, args.at(0).template TryGet()); + phi::Place place = phi::XPUPlace(device_id); + return GetPlaceClass().New(place); +} + +inline adt::Result CreateIPUPlace( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(device_id, args.at(0).template TryGet()); + phi::Place place = phi::IPUPlace(device_id); + return GetPlaceClass().New(place); +} + +inline adt::Result CreateCustomPlace( + const axpr::Value& self_val, const std::vector& args) { + std::optional place; + if (args.size() == 1) { + ADT_LET_CONST_REF(dev_type, args.at(0).template TryGet()); + place = phi::CustomPlace(dev_type); + } else if (args.size() == 2) { + ADT_LET_CONST_REF(dev_type, args.at(0).template TryGet()); + ADT_LET_CONST_REF(device_id, args.at(1).template TryGet()); + place = phi::CustomPlace(dev_type, device_id); + } else { + return adt::errors::TypeError{std::string() + + "CustomPlace() takes 1 or 2 arguments, but " + + std::to_string(args.size()) + " were given"}; + } + ADT_CHECK(place.has_value()); + return GetPlaceClass().New(place.value()); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/phi/scalar_helper.h b/paddle/ap/include/paddle/phi/scalar_helper.h new file mode 100644 index 00000000000000..69cc0083ffa8e9 --- /dev/null +++ b/paddle/ap/include/paddle/phi/scalar_helper.h @@ -0,0 +1,92 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/data_value.h" +#include "paddle/phi/common/scalar.h" + +namespace ap::paddle { + +struct ScalarHelper { + adt::Result ConvertFromDataType( + const axpr::DataValue& data_val) { + using RetT = adt::Result; + return data_val.Match( + [&](double c) -> RetT { return phi::Scalar(c); }, + [&](float c) -> RetT { return phi::Scalar(c); }, + [&](axpr::float16 c) -> RetT { return phi::Scalar(c); }, + [&](axpr::bfloat16 c) -> RetT { return phi::Scalar(c); }, + [&](int64_t c) -> RetT { return phi::Scalar(c); }, + [&](int32_t c) -> RetT { return phi::Scalar(c); }, + [&](int16_t c) -> RetT { return phi::Scalar(c); }, + [&](int8_t c) -> RetT { return phi::Scalar(c); }, + [&](uint64_t c) -> RetT { return phi::Scalar(c); }, + [&](uint32_t c) -> RetT { return phi::Scalar(c); }, + [&](uint16_t c) -> RetT { return phi::Scalar(c); }, + [&](uint8_t c) -> RetT { return phi::Scalar(c); }, + [&](bool c) -> RetT { return phi::Scalar(c); }, + [&](const axpr::complex64& c) -> RetT { return phi::Scalar(c); }, + [&](const axpr::complex128& c) -> RetT { return phi::Scalar(c); }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + "ConvertFromDataType(): can not convert from " + + data_val.GetType().Name() + " to phi::Scalar"}; + }); + } + + adt::Result ConvertToDataValue(const phi::Scalar& scalar) { + switch (scalar.dtype()) { + case phi::DataType::FLOAT32: + return axpr::DataValue(scalar.to()); + case phi::DataType::FLOAT64: + return 
axpr::DataValue(scalar.to()); + case phi::DataType::FLOAT16: + return axpr::DataValue(scalar.to()); + case phi::DataType::BFLOAT16: + return axpr::DataValue(scalar.to()); + case phi::DataType::INT32: + return axpr::DataValue(scalar.to()); + case phi::DataType::INT64: + return axpr::DataValue(scalar.to()); + case phi::DataType::INT16: + return axpr::DataValue(scalar.to()); + case phi::DataType::INT8: + return axpr::DataValue(scalar.to()); + case phi::DataType::UINT64: + return axpr::DataValue(scalar.to()); + case phi::DataType::UINT32: + return axpr::DataValue(scalar.to()); + case phi::DataType::UINT16: + return axpr::DataValue(scalar.to()); + case phi::DataType::UINT8: + return axpr::DataValue(scalar.to()); + case phi::DataType::BOOL: + return axpr::DataValue(scalar.to()); + case phi::DataType::COMPLEX64: + return axpr::DataValue(scalar.to()); + case phi::DataType::COMPLEX128: + return axpr::DataValue(scalar.to()); + default: + std::ostringstream ss; + ss << scalar.dtype(); + return adt::errors::TypeError{std::string() + + "Invalid enum scalar data type `" + + ss.str() + "`."}; + } + } +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/attr_adt_type_id.h b/paddle/ap/include/paddle/pir/attr_adt_type_id.h new file mode 100644 index 00000000000000..7d56612f3b0f8a --- /dev/null +++ b/paddle/ap/include/paddle/pir/attr_adt_type_id.h @@ -0,0 +1,117 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/common/adt_type_id.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" + +namespace pir { + +class Attribute; +class BoolAttribute; +class Complex64Attribute; +class Complex128Attribute; +class FloatAttribute; +class DoubleAttribute; +class Int32Attribute; +class IndexAttribute; +class Int64Attribute; +class PointerAttribute; +class TypeAttribute; +class StrAttribute; +class ArrayAttribute; +class TensorNameAttribute; + +} // namespace pir + +namespace pir::shape { + +class SymbolAttribute; + +} + +namespace paddle::dialect { + +class KernelAttribute; +class IntArrayAttribute; +class ScalarAttribute; +class DataTypeAttribute; +class PlaceAttribute; +class DataLayoutAttribute; + +} // namespace paddle::dialect + +namespace cinn::dialect { + +class GroupInfoAttribute; +class CINNKernelInfoAttribute; + +} // namespace cinn::dialect + +namespace ap::paddle { + +struct UnclassifiedAttribute { + static const char* name() { return "a_unclassified"; } +}; + +// clang-format off +#define FOR_EACH_PIR_ATTRIBUTE_TYPE(__macro) \ + __macro(pir::BoolAttribute) \ + __macro(pir::Complex64Attribute) \ + __macro(pir::Complex128Attribute) \ + __macro(pir::FloatAttribute) \ + __macro(pir::DoubleAttribute) \ + __macro(pir::Int32Attribute) \ + __macro(pir::IndexAttribute) \ + __macro(pir::Int64Attribute) \ + __macro(pir::PointerAttribute) \ + __macro(pir::TypeAttribute) \ + __macro(pir::StrAttribute) \ + __macro(pir::ArrayAttribute) \ + __macro(pir::TensorNameAttribute) \ + __macro(pir::shape::SymbolAttribute) \ + __macro(::paddle::dialect::KernelAttribute) \ + 
__macro(::paddle::dialect::IntArrayAttribute) \ + __macro(::paddle::dialect::ScalarAttribute) \ + __macro(::paddle::dialect::DataTypeAttribute) \ + __macro(::paddle::dialect::PlaceAttribute) \ + __macro(::paddle::dialect::DataLayoutAttribute) \ + __macro(::cinn::dialect::GroupInfoAttribute) \ + __macro(::cinn::dialect::CINNKernelInfoAttribute) +// clang-format on + +using AttrAdtTypeIdBase = ::common::AdtBaseTypeId< +#define AS_ATTR_ADT_TYPE_ID_ALTERNATIVE(cls) cls, + FOR_EACH_PIR_ATTRIBUTE_TYPE(AS_ATTR_ADT_TYPE_ID_ALTERNATIVE) +#undef AS_ATTR_ADT_TYPE_ID_ALTERNATIVE + UnclassifiedAttribute>; + +struct AttrAdtTypeId : public AttrAdtTypeIdBase { + using AttrAdtTypeIdBase::AttrAdtTypeIdBase; +}; + +inline AttrAdtTypeId GetAttrAdtTypeId(const pir::Attribute& attr) { +#define RETURN_ATTR_TYPE_ID_IF_MATCH(cls) \ + if (attr.isa()) return ::common::AdtTypeId{}; + FOR_EACH_PIR_ATTRIBUTE_TYPE(RETURN_ATTR_TYPE_ID_IF_MATCH) +#undef RETURN_ATTR_TYPE_ID_IF_MATCH + return ::common::AdtTypeId{}; +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/attribute.h b/paddle/ap/include/paddle/pir/attribute.h new file mode 100644 index 00000000000000..527d13e070c69b --- /dev/null +++ b/paddle/ap/include/paddle/pir/attribute.h @@ -0,0 +1,24 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/paddle/pir/attr_adt_type_id.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" + +namespace ap::paddle {} diff --git a/paddle/ap/include/paddle/pir/attribute_method_class.h b/paddle/ap/include/paddle/pir/attribute_method_class.h new file mode 100644 index 00000000000000..6f9348bbfed869 --- /dev/null +++ b/paddle/ap/include/paddle/pir/attribute_method_class.h @@ -0,0 +1,264 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/phi/scalar_helper.h" +#include "paddle/ap/include/paddle/pir/attr_adt_type_id.h" +#include "paddle/ap/include/paddle/pir/attribute.h" + +namespace ap::paddle { + +axpr::TypeImpl> GetPirAttributeClass(); + +template +struct MakePirAttributeImpl; + +struct MakePirAttributeImplBoolAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplBoolAttribute {}; + +struct MakePirAttributeImplComplex64Attribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplComplex64Attribute {}; + +struct MakePirAttributeImplComplex128Attribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplComplex128Attribute {}; + +struct MakePirAttributeImplFloatAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplFloatAttribute {}; + +struct MakePirAttributeImplDoubleAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct 
MakePirAttributeImpl + : public MakePirAttributeImplDoubleAttribute {}; + +struct MakePirAttributeImplInt32Attribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplInt32Attribute {}; + +struct MakePirAttributeImplIndexAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplIndexAttribute {}; + +struct MakePirAttributeImplInt64Attribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplInt64Attribute {}; + +struct MakePirAttributeImplPointerAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplPointerAttribute {}; + +struct MakePirAttributeImplTypeAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplTypeAttribute {}; + +struct MakePirAttributeImplStrAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplStrAttribute {}; + +struct MakePirAttributeImplArrayAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + 
const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplArrayAttribute {}; + +struct MakePirAttributeImplTensorNameAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplTensorNameAttribute {}; + +struct MakePirAttributeImplSymbolAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplSymbolAttribute {}; + +struct MakePirAttributeImplKernelAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::paddle::dialect::KernelAttribute> + : public MakePirAttributeImplKernelAttribute {}; + +struct MakePirAttributeImplIntArrayAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::paddle::dialect::IntArrayAttribute> + : public MakePirAttributeImplIntArrayAttribute {}; + +struct MakePirAttributeImplScalarAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::paddle::dialect::ScalarAttribute> + : public MakePirAttributeImplScalarAttribute {}; + +struct MakePirAttributeImplDataTypeAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct 
MakePirAttributeImpl<::paddle::dialect::DataTypeAttribute> + : public MakePirAttributeImplDataTypeAttribute {}; + +struct MakePirAttributeImplPlaceAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::paddle::dialect::PlaceAttribute> + : public MakePirAttributeImplPlaceAttribute {}; + +struct MakePirAttributeImplDataLayoutAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::paddle::dialect::DataLayoutAttribute> + : public MakePirAttributeImplDataLayoutAttribute {}; + +struct MakePirAttributeImplGroupInfoAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::cinn::dialect::GroupInfoAttribute> + : public MakePirAttributeImplGroupInfoAttribute {}; + +struct MakePirAttributeImplCINNKernelInfoAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl<::cinn::dialect::CINNKernelInfoAttribute> + : public MakePirAttributeImplCINNKernelInfoAttribute {}; + +struct MakePirAttributeImplUnclassifiedAttribute { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirAttributeImpl + : public MakePirAttributeImplUnclassifiedAttribute {}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/manual_op.h b/paddle/ap/include/paddle/pir/manual_op.h new file mode 100644 index 00000000000000..75dd12529ab314 --- /dev/null +++ 
b/paddle/ap/include/paddle/pir/manual_op.h @@ -0,0 +1,142 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/pir/include/core/builder.h" +#include "paddle/pir/include/core/op_base.h" +#include "paddle/pir/include/core/op_trait.h" +#include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/core/operation_utils.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace ap::dialect { + +class IR_API UpSpiderOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.up_spider"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value lhs, + pir::Value rhs); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context) { + return true; + } +}; + +class IR_API DownSpiderOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.down_spider"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + 
pir::OperationArgument &argument, // NOLINT + pir::Value input); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); +}; + +class IR_API LoadFromRegisterOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.load_from_register"; } + static constexpr uint32_t attributes_num = 4; + static const char *attributes_name[attributes_num]; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Type output_type, + const symbol::ShapeOrDataDimExprs &shape_or_data, + const std::string &name, + const std::string ®ister_var_name); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); +}; + +class IR_API StoreToRegisterOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.store_to_register"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + const std::string &name, + const std::string ®ister_var_name); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); +}; + +class IR_API LoadFromGlobalOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.load_from_global"; } + static constexpr uint32_t attributes_num = 1; + static const char *attributes_name[attributes_num]; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + const std::string &index_func_unique_id); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); +}; + +class IR_API StoreToGlobalOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.store_to_global"; } + static constexpr uint32_t 
attributes_num = 1; + static const char *attributes_name[attributes_num]; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value var, + pir::Value val, + const std::string &index_func_unique_id); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); +}; + +} // namespace ap::dialect + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::StoreToRegisterOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromGlobalOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::StoreToGlobalOp); diff --git a/paddle/ap/include/paddle/pir/op_dialect.h b/paddle/ap/include/paddle/pir/op_dialect.h new file mode 100644 index 00000000000000..b3c6e86e7780ff --- /dev/null +++ b/paddle/ap/include/paddle/pir/op_dialect.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/pir/include/core/dialect.h" + +namespace ap { +namespace dialect { + +class OperatorDialect : public ::pir::Dialect { + public: + explicit OperatorDialect(::pir::IrContext* context); + + static const char* name() { return "ap_op"; } + + private: + void initialize(); +}; + +} // namespace dialect +} // namespace ap + +IR_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::OperatorDialect) diff --git a/paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h b/paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h new file mode 100644 index 00000000000000..f79498995f11de --- /dev/null +++ b/paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/ir_match/graph_match_ctx.h" +#include "paddle/ap/include/paddle/pir_node.h" + +namespace ap::drr { + +struct SourcePatternCtx; + +} + +namespace ap::paddle { + +struct PackedIrOp; + +struct PackedIrOpInnerSourcePatternHelper { + adt::Result>> Match( + const PackedIrOp& ir_op, const drr::SourcePatternCtx& src_ptn_ctx); + adt::Result>> Match( + const pir::Block* block, const drr::SourcePatternCtx& src_ptn_ctx); +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pass.h b/paddle/ap/include/paddle/pir/pass.h new file mode 100644 index 00000000000000..3c561c7fca6db6 --- /dev/null +++ b/paddle/ap/include/paddle/pir/pass.h @@ -0,0 +1,30 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/pir/include/pass/pass.h" + +namespace ap::paddle { + +struct PassImpl { + std::unique_ptr<::pir::Pass> pir_pass; + + bool operator==(const PassImpl& other) const { return this == &other; } +}; + +ADT_DEFINE_RC(Pass, PassImpl); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pass_manager.h b/paddle/ap/include/paddle/pir/pass_manager.h new file mode 100644 index 00000000000000..902d3d8764b37c --- /dev/null +++ b/paddle/ap/include/paddle/pir/pass_manager.h @@ -0,0 +1,30 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/pir/include/pass/pass_manager.h" + +namespace ap::paddle { + +struct PassManagerImpl { + std::shared_ptr<::pir::PassManager> pir_pass_manager; + + bool operator==(const PassManagerImpl& other) const { return this == &other; } +}; + +ADT_DEFINE_RC(PassManager, PassManagerImpl); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pass_manager_method_class.h b/paddle/ap/include/paddle/pir/pass_manager_method_class.h new file mode 100644 index 00000000000000..797ef9ba1ff8b2 --- /dev/null +++ b/paddle/ap/include/paddle/pir/pass_manager_method_class.h @@ -0,0 +1,31 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir/pass.h" +#include "paddle/ap/include/paddle/pir/pass_manager.h" +#include "paddle/ap/include/paddle/pir/program.h" +#include "paddle/pir/include/core/ir_printer.h" + +namespace ap::paddle { + +axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> +GetPirPassManagerClass(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pass_method_class.h b/paddle/ap/include/paddle/pir/pass_method_class.h new file mode 100644 index 00000000000000..d5b15850cbb45b --- /dev/null +++ b/paddle/ap/include/paddle/pir/pass_method_class.h @@ -0,0 +1,27 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir/pass.h" +#include "paddle/pir/include/core/ir_printer.h" + +namespace ap::paddle { + +axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> GetPirPassClass(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pir.h b/paddle/ap/include/paddle/pir/pir.h new file mode 100644 index 00000000000000..2c6c495e17c6fb --- /dev/null +++ b/paddle/ap/include/paddle/pir/pir.h @@ -0,0 +1,23 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" + +namespace ap::paddle { + +struct Pir : public std::monostate {}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pir_method_class.h b/paddle/ap/include/paddle/pir/pir_method_class.h new file mode 100644 index 00000000000000..95572c35d7fa9b --- /dev/null +++ b/paddle/ap/include/paddle/pir/pir_method_class.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/phi/place_method_class.h" +#include "paddle/ap/include/paddle/pir/attribute_method_class.h" +#include "paddle/ap/include/paddle/pir/pir.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" + +namespace ap::paddle { + +void ForceLinkPir(); +axpr::TypeImpl> GetPirClass(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h b/paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h new file mode 100644 index 00000000000000..34da34a409596b --- /dev/null +++ b/paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h @@ -0,0 +1,62 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/drr/drr_ctx.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/ir_match/graph_match_ctx.h" +#include "paddle/ap/include/paddle/pir_node.h" +#include "paddle/ap/include/reified_drr/matched_src_ptn_ctx_helper.h" + +namespace pir { + +class Block; + +} + +namespace ap::paddle { + +struct PirNodeMatchedSrcPtnCtxHelper + : public reified_drr::MatchedSrcPtnCtxHelper { + PirNodeMatchedSrcPtnCtxHelper( + const drr::SourcePatternCtx& src_ptn_ctx, + const ir_match::GraphMatchCtx& match_ctx) + : src_ptn_ctx_(src_ptn_ctx), match_ctx_(match_ctx) {} + + virtual drr::SourcePatternCtx src_ptn_ctx() { return src_ptn_ctx_; } + + adt::Result> + MakeInnerMatchedSrcPtnCtxHelper( + const drr::PackedIrOp& drr_packed_ir_op) override; + + adt::Result VisitNativeIrOpAttr( + const drr::NativeIrOp& drr_native_ir_op, + const std::function(const std::string& attr_name, + const axpr::Value& attr_val)>& + DoEachAttr) override; + + adt::Result GetNativeIrValueType( + const drr::NativeIrValue& native_ir_value) override; + + private: + adt::Result ConvertBlockToSrcPtnCtx( + pir::Block* block, const std::shared_ptr& drr_ctx); + + drr::SourcePatternCtx src_ptn_ctx_; + ir_match::GraphMatchCtx match_ctx_; +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h b/paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h new file mode 100644 index 00000000000000..dfd9e87d1d7a13 --- /dev/null +++ b/paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h @@ -0,0 +1,37 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/paddle/pir/attr_adt_type_id.h" +#include "paddle/ap/include/paddle/pir/type_adt_type_id.h" +#include "paddle/ap/include/reified_drr/drr_node_attr_to_anf_expr_helper.h" + +namespace ap::paddle { + +struct PirToAnfExprHelper : public reified_drr::DrrNodeAttrToAnfExprHelper { + adt::Result ConvertTypeToAnfExpr(axpr::LetContext* ctx, + axpr::Value type) override; + adt::Result ConvertAttrToAnfExpr(axpr::LetContext* ctx, + axpr::Value attr) override; + + adt::Result ConvertPirTypeToAnfExpr(axpr::LetContext* ctx, + pir::Type type); + adt::Result ConvertPirAttrToAnfExpr(axpr::LetContext* ctx, + pir::Attribute attr); +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/program.h b/paddle/ap/include/paddle/pir/program.h new file mode 100644 index 00000000000000..58d32a1eb1a335 --- /dev/null +++ b/paddle/ap/include/paddle/pir/program.h @@ -0,0 +1,30 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/pir/include/core/program.h" + +namespace ap::paddle { + +struct ProgramImpl { + std::shared_ptr<::pir::Program> pir_program; + + bool operator==(const ProgramImpl& other) const { return this == &other; } +}; + +ADT_DEFINE_RC(Program, ProgramImpl); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/program_method_class.h b/paddle/ap/include/paddle/pir/program_method_class.h new file mode 100644 index 00000000000000..aaf56917f8bd3e --- /dev/null +++ b/paddle/ap/include/paddle/pir/program_method_class.h @@ -0,0 +1,27 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir/program.h" +#include "paddle/pir/include/core/ir_printer.h" + +namespace ap::paddle { + +axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>> GetPirProgramClass(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/shape_or_data_method_class.h b/paddle/ap/include/paddle/pir/shape_or_data_method_class.h new file mode 100644 index 00000000000000..8bf23b88836582 --- /dev/null +++ b/paddle/ap/include/paddle/pir/shape_or_data_method_class.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir/type.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" + +namespace ap::paddle { + +axpr::TypeImpl> +GetPirShapeOrDataClass(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/type.h b/paddle/ap/include/paddle/pir/type.h new file mode 100644 index 00000000000000..a19b5889a3bafa --- /dev/null +++ b/paddle/ap/include/paddle/pir/type.h @@ -0,0 +1,19 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pir/type_adt_type_id.h" + +namespace ap::paddle {} diff --git a/paddle/ap/include/paddle/pir/type_adt_type_id.h b/paddle/ap/include/paddle/pir/type_adt_type_id.h new file mode 100644 index 00000000000000..2a23ef2f2145a2 --- /dev/null +++ b/paddle/ap/include/paddle/pir/type_adt_type_id.h @@ -0,0 +1,109 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/common/adt_type_id.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/core/builtin_attribute.h" + +namespace pir { +class Type; +class VectorType; +class DenseTensorType; +class BFloat16Type; +class Float16Type; +class Float32Type; +class Float64Type; +class Int8Type; +class UInt8Type; +class Int16Type; +class Int32Type; +class Int64Type; +class IndexType; +class BoolType; +class Complex64Type; +class Complex128Type; + +} // namespace pir + +namespace paddle::dialect { + +class SelectedRowsType; +class DenseTensorArrayType; +class SparseCooTensorType; +class SparseCsrTensorType; + +} // namespace paddle::dialect + +// clang-format off +#define FOR_EACH_PIR_ALTERNATIVE_TYPE(__macro) \ + __macro(::pir::VectorType) \ + __macro(::pir::DenseTensorType) \ + __macro(::pir::BFloat16Type) \ + __macro(::pir::Float16Type) \ + __macro(::pir::Float32Type) \ + __macro(::pir::Float64Type) \ + __macro(::pir::Int8Type) \ + __macro(::pir::UInt8Type) \ + __macro(::pir::Int16Type) \ + __macro(::pir::Int32Type) \ + __macro(::pir::Int64Type) \ + __macro(::pir::IndexType) \ + __macro(::pir::BoolType) \ + __macro(::pir::Complex64Type) \ + 
__macro(::pir::Complex128Type) \ + __macro(::paddle::dialect::SelectedRowsType) \ + __macro(::paddle::dialect::DenseTensorArrayType) \ + __macro(::paddle::dialect::SparseCooTensorType) \ + __macro(::paddle::dialect::SparseCsrTensorType) +// clang-format on + +namespace ap::paddle { + +struct NullType { + static const char* name() { return "t_null"; } +}; + +struct UnclassifiedType { + static const char* name() { return "t_unclassified"; } +}; + +using TypeAdtTypeIdBase = + ::common::AdtBaseTypeId; + +struct TypeAdtTypeId : public TypeAdtTypeIdBase { + using TypeAdtTypeIdBase::TypeAdtTypeIdBase; +}; + +inline TypeAdtTypeId GetTypeAdtTypeId(const pir::Type& type) { + if (!type) { + return ::common::AdtTypeId{}; + } +#define RETURN_TYPE_TYPE_ID_IF_MATCH(cls) \ + if (type.isa()) return ::common::AdtTypeId{}; + FOR_EACH_PIR_ALTERNATIVE_TYPE(RETURN_TYPE_TYPE_ID_IF_MATCH) +#undef RETURN_TYPE_TYPE_ID_IF_MATCH + return ::common::AdtTypeId{}; +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/type_method_class.h b/paddle/ap/include/paddle/pir/type_method_class.h new file mode 100644 index 00000000000000..8e6f863aa1a6ce --- /dev/null +++ b/paddle/ap/include/paddle/pir/type_method_class.h @@ -0,0 +1,236 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir/type.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" + +namespace ap::paddle { + +axpr::TypeImpl> GetPirTypeClass(); + +template +struct MakePirTypeImpl; + +struct MakePirTypeImplNullType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl : public MakePirTypeImplNullType {}; + +struct MakePirTypeImplVectorType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::VectorType> : public MakePirTypeImplVectorType {}; + +struct MakePirTypeImplDenseTensorType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::DenseTensorType> + : public MakePirTypeImplDenseTensorType {}; + +struct MakePirTypeImplBFloat16Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::BFloat16Type> + : public MakePirTypeImplBFloat16Type {}; + +struct MakePirTypeImplFloat16Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const 
axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Float16Type> : public MakePirTypeImplFloat16Type { +}; + +struct MakePirTypeImplFloat32Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Float32Type> : public MakePirTypeImplFloat32Type { +}; + +struct MakePirTypeImplFloat64Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Float64Type> : public MakePirTypeImplFloat64Type { +}; + +struct MakePirTypeImplInt8Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Int8Type> : public MakePirTypeImplInt8Type {}; + +struct MakePirTypeImplUInt8Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::UInt8Type> : public MakePirTypeImplUInt8Type {}; + +struct MakePirTypeImplInt16Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Int16Type> : public MakePirTypeImplInt16Type {}; + +struct MakePirTypeImplInt32Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Int32Type> : public MakePirTypeImplInt32Type {}; + +struct MakePirTypeImplInt64Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( 
+ const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Int64Type> : public MakePirTypeImplInt64Type {}; + +struct MakePirTypeImplIndexType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::IndexType> : public MakePirTypeImplIndexType {}; + +struct MakePirTypeImplBoolType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::BoolType> : public MakePirTypeImplBoolType {}; + +struct MakePirTypeImplComplex64Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Complex64Type> + : public MakePirTypeImplComplex64Type {}; + +struct MakePirTypeImplComplex128Type { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::pir::Complex128Type> + : public MakePirTypeImplComplex128Type {}; + +struct MakePirTypeImplSelectedRowsType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::paddle::dialect::SelectedRowsType> + : public MakePirTypeImplSelectedRowsType {}; + +struct MakePirTypeImplDenseTensorArrayType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::paddle::dialect::DenseTensorArrayType> + : public MakePirTypeImplDenseTensorArrayType {}; + +struct MakePirTypeImplSparseCooTensorType { + static 
adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::paddle::dialect::SparseCooTensorType> + : public MakePirTypeImplSparseCooTensorType {}; + +struct MakePirTypeImplSparseCsrTensorType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl<::paddle::dialect::SparseCsrTensorType> + : public MakePirTypeImplSparseCsrTensorType {}; + +struct MakePirTypeImplUnclassifiedType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args); + static adt::Result> GetCallArgs( + const axpr::Value& self_val); +}; +template <> +struct MakePirTypeImpl + : public MakePirTypeImplUnclassifiedType {}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir_graph_descriptor.h b/paddle/ap/include/paddle/pir_graph_descriptor.h new file mode 100644 index 00000000000000..2c6b5b8e6c31cb --- /dev/null +++ b/paddle/ap/include/paddle/pir_graph_descriptor.h @@ -0,0 +1,591 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/drr/topo_kind.h" +#include "paddle/ap/include/graph/graph_descriptor.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/ir_match/ref_match_ctx.h" +#include "paddle/ap/include/paddle/pir_node.h" +#include "paddle/ap/include/paddle/pir_util.h" + +namespace ap::paddle { + +struct DefaultPirGraphDescriptor { + using NodeT = PirNode; + + NodeT CastToIrOpResult(const pir::OpResult& op_result) const { + if (op_result.owner()->isa()) { + return PackedIrOpResult{op_result}; + } else { + return NativeIrOpResult{op_result}; + } + } + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + return node.Match( + [&](const NativeIrValue& impl) -> adt::Result { + if (pir::OpResult::classof(impl.value)) { + return DoEach( + CastToIrOpResult(pir::OpResult::dyn_cast_from(impl.value))); + } + return adt::Ok{}; + }, + [&](const PackedIrValue& impl) -> adt::Result { + // TODO(tianchao): support the following case: + // o.trivial_op0([*t.inputs], [t.op0_output, *t.op0_output1]) + // o.trivial_op1([*.t.op0_output1], [t.op1_output]) + return adt::errors::NotImplementedError{ + "DefaultPirGraphDescriptor::VisitUpstreamNodes does not support " + "PackedIrValue"}; + }, + [&](const RefIrValue& impl) -> adt::Result { + RefIrOpResult ir_op_result{impl.ref_node_info}; + return DoEach(ir_op_result); + }, + [&](const NativeIrOpOperand& impl) -> adt::Result { + NativeIrValue ir_value{impl.op_operand.source()}; + return DoEach(ir_value); + }, + [&](const PackedIrOpOperand& impl) -> adt::Result { + const auto& inputs = GetFusionOpInputValues(impl.fusion_op); + ADT_CHECK(impl.free_tensor_index >= 0); + ADT_CHECK(impl.free_tensor_index < inputs.size()); + NativeIrValue ir_value{inputs.at(impl.free_tensor_index)}; + return DoEach(ir_value); + }, + [&](const RefIrOpOperand& impl) -> adt::Result { + return DoEach(impl.ref_node_info->ir_value); + }, + 
[&](const NativeIrOp& impl) -> adt::Result { + for (int i = 0; i < impl.op->num_operands(); ++i) { + NativeIrOpOperand ir_op_operand{impl.op->operand(i)}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + return adt::Ok{}; + }, + [&](const PackedIrOp& impl) -> adt::Result { + const auto& inputs = GetFusionOpInputValues(impl.fusion_op); + for (int i = 0; i < inputs.size(); ++i) { + PackedIrOpOperand ir_op_operand{impl.fusion_op, i}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + return adt::Ok{}; + }, + [&](const RefIrOp& impl) -> adt::Result { + RefIrOpOperand ir_op_operand{impl.ref_node_info}; + return DoEach(ir_op_operand); + }, + [&](const NativeIrOpResult& impl) -> adt::Result { + NativeIrOp ir_op{impl.op_result.defining_op()}; + return DoEach(ir_op); + }, + [&](const PackedIrOpResult& impl) -> adt::Result { + auto* op = impl.op_result.defining_op(); + ADT_CHECK(op->isa()); + PackedIrOp ir_op{op->dyn_cast()}; + return DoEach(ir_op); + }, + [&](const RefIrOpResult& impl) -> adt::Result { + RefIrOp ir_op{impl.ref_node_info}; + return DoEach(ir_op); + }); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + return node.Match( + [&](const NativeIrValue& impl) -> adt::Result { + for (auto iter = impl.value.use_begin(); iter != impl.value.use_end(); + ++iter) { + auto* user_parent_block = iter->owner()->GetParent(); + ADT_CHECK(user_parent_block != nullptr); + auto* user_parent_op = user_parent_block->GetParentOp(); + if (user_parent_op->isa()) { + auto fusion_op = + user_parent_op->dyn_cast(); + const auto& user_op_inputs = GetFusionOpInputValues(fusion_op); + for (int i = 0; i < user_op_inputs.size(); ++i) { + if (user_op_inputs.at(i) == impl.value) { + PackedIrOpOperand ir_op_operand{fusion_op, i}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + } + } else { + pir::OpOperand op_operand = *iter; + NativeIrOpOperand ir_op_operand{op_operand}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + } + return 
adt::Ok{}; + }, + [&](const PackedIrValue& impl) -> adt::Result { + // TODO(tianchao): support the following case: + // o.trivial_op0([*t.inputs], [t.op0_output, *t.op0_output1]) + // o.trivial_op1([*.t.op0_output1], [t.op1_output]) + return adt::Ok{}; + }, + [&](const RefIrValue& impl) -> adt::Result { + for (const auto& ir_op_operand : + *impl.ref_node_info->op_operands_subset) { + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + return adt::Ok{}; + }, + [&](const NativeIrOpOperand& impl) -> adt::Result { + NativeIrOp ir_op{impl.op_operand.owner()}; + return DoEach(ir_op); + }, + [&](const PackedIrOpOperand& impl) -> adt::Result { + PackedIrOp ir_op{impl.fusion_op}; + return DoEach(ir_op); + }, + [&](const RefIrOpOperand& impl) -> adt::Result { + RefIrOp ir_op{impl.ref_node_info}; + return DoEach(ir_op); + }, + [&](const NativeIrOp& impl) -> adt::Result { + for (int i = 0; i < impl.op->num_results(); ++i) { + const auto& value = impl.op->result(i); + ADT_CHECK(pir::OpResult::classof(value)); + NativeIrOpResult ir_op_result{pir::OpResult::dyn_cast_from(value)}; + ADT_RETURN_IF_ERR(DoEach(ir_op_result)); + } + return adt::Ok{}; + }, + [&](const PackedIrOp& impl) -> adt::Result { + for (int i = 0; i < impl.fusion_op->num_results(); ++i) { + const auto& value = impl.fusion_op->result(i); + ADT_CHECK(pir::OpResult::classof(value)); + PackedIrOpResult ir_op_result{pir::OpResult::dyn_cast_from(value)}; + ADT_RETURN_IF_ERR(DoEach(ir_op_result)); + } + return adt::Ok{}; + }, + [&](const RefIrOp& impl) -> adt::Result { + RefIrOpResult ir_op_result{impl.ref_node_info}; + return DoEach(ir_op_result); + }, + [&](const NativeIrOpResult& impl) -> adt::Result { + pir::Value value = impl.op_result; + NativeIrValue ir_value{value}; + return DoEach(ir_value); + }, + [&](const PackedIrOpResult& impl) -> adt::Result { + pir::Value value = impl.op_result; + NativeIrValue ir_value{value}; + return DoEach(ir_value); + }, + [&](const RefIrOpResult& impl) -> adt::Result { + RefIrValue 
ir_value{impl.ref_node_info}; + return DoEach(ir_value); + }); + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return graph::SmallGraphNodeTopoCstr{node.node_topo_cstr()}; + } + + adt::Result IgnoredNode(const NodeT& node) const { + return node.Match( + [](const PackedIrValue&) -> adt::Result { return true; }, + [](const auto&) -> adt::Result { return false; }); + } + + adt::Result IsOpNode(const NodeT& node) const { + return node.Match([&](const NativeIrOp&) -> bool { return true; }, + [&](const PackedIrOp&) -> bool { return true; }, + [&](const RefIrOp&) -> bool { return true; }, + [&](const auto&) -> bool { return false; }); + } + + adt::Result IsValueNode(const NodeT& node) const { + return node.Match([&](const NativeIrValue&) -> bool { return true; }, + [&](const PackedIrValue&) -> bool { return true; }, + [&](const RefIrValue&) -> bool { return true; }, + [&](const auto&) -> bool { return false; }); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + graph::BigGraphNodeTopoCstr bg_node_topo_cstr{node.node_topo_cstr()}; + return bg_node_topo_cstr.TopoSatisfy(node_topo_cstr); + } + + const std::vector& GetFusionOpInputValues( + cinn::dialect::FusionOp fusion_op) const { + auto iter = fusion_op2input_values_.find(fusion_op); + if (iter == fusion_op2input_values_.end()) { + iter = + fusion_op2input_values_ + .emplace(fusion_op, ap::paddle::GetUsedExternalValue(*fusion_op)) + .first; + } + return iter->second; + } + + private: + mutable std::unordered_map> + fusion_op2input_values_; +}; + +struct RefAugmentedPirGraphDescriptor { + using NodeT = PirNode; + using RefNodeInfo = ir_match::RefNodeInfo; + using RefMatchCtx = ir_match::RefMatchCtx; + RefMatchCtx ref_match_ctx; + DefaultPirGraphDescriptor backend_graph; + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + using Ok = adt::Result; + return node.Match( + [&](const 
NativeIrOpOperand& impl) -> Ok { + const auto iter = ref_match_ctx->operand2node_info.find(impl); + if (iter == ref_match_ctx->operand2node_info.end()) { + return backend_graph.VisitUpstreamNodes(node, DoEach); + } + RefIrValue ref_ir_value{iter->second}; + return DoEach(ref_ir_value); + }, + [&](const auto&) -> Ok { + return backend_graph.VisitUpstreamNodes(node, DoEach); + }); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + using Ok = adt::Result; + return node.Match( + [&](const NativeIrValue& impl) -> Ok { + const auto iter = ref_match_ctx->value2ref_node_info.find(impl); + if (iter == ref_match_ctx->value2ref_node_info.end()) { + return backend_graph.VisitDownstreamNodes(node, DoEach); + } + for (const auto& ref_node_info : iter->second) { + RefIrOpOperand ir_op_operand{ref_node_info}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + return adt::Ok{}; + }, + [&](const auto&) -> Ok { + return backend_graph.VisitDownstreamNodes(node, DoEach); + }); + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return backend_graph.GetSmallGraphNodeTopoCstr(node); + } + + adt::Result IgnoredNode(const NodeT& node) const { + return backend_graph.IgnoredNode(node); + } + + adt::Result IsOpNode(const NodeT& node) const { + return backend_graph.IsOpNode(node); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + return backend_graph.TopoSatisfy(node, node_topo_cstr); + } +}; + +struct AllOperandAndResultPirGraphDescriptor { + using NodeT = PirNode; + + DefaultPirGraphDescriptor backend_graph; + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + auto DoEachOpOrValue = [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_op_node, backend_graph.IsOpNode(upstream)); + ADT_LET_CONST_REF(is_value_node, backend_graph.IsValueNode(upstream)); + ADT_CHECK(is_op_node || 
is_value_node); + return backend_graph.VisitUpstreamNodes(upstream, DoEach); + }; + return backend_graph.VisitUpstreamNodes(node, DoEachOpOrValue); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + auto DoEachOpOrValue = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_op_node, backend_graph.IsOpNode(downstream)); + ADT_LET_CONST_REF(is_value_node, backend_graph.IsValueNode(downstream)); + ADT_CHECK(is_op_node || is_value_node); + return backend_graph.VisitDownstreamNodes(downstream, DoEach); + }; + return backend_graph.VisitDownstreamNodes(node, DoEachOpOrValue); + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return backend_graph.GetSmallGraphNodeTopoCstr(node); + } + + adt::Result IgnoredNode(const NodeT& node) const { + ADT_LET_CONST_REF(is_op_node, backend_graph.IsOpNode(node)); + ADT_LET_CONST_REF(is_value_node, backend_graph.IsValueNode(node)); + if (is_op_node || is_value_node) { + return true; + } + return backend_graph.IgnoredNode(node); + } + + adt::Result IsOpNode(const NodeT& node) const { + return backend_graph.IsOpNode(node); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + return backend_graph.TopoSatisfy(node, node_topo_cstr); + } +}; + +struct NativeOperandAndResultPirGraphDescriptor { + using NodeT = PirNode; + + AllOperandAndResultPirGraphDescriptor backend_graph; + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + ADT_LET_CONST_REF(is_node_native, IsNative(node)); + ADT_CHECK(is_node_native); + auto VisitEachNative = [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_upstream_native, IsNative(upstream)); + ADT_CHECK(!is_upstream_native); + return backend_graph.VisitUpstreamNodes(upstream, DoEach); + }; + auto VisitEachPacked = [&](const NodeT& upstream) -> adt::Result { + 
ADT_LET_CONST_REF(is_upstream_native, IsNative(upstream)); + ADT_CHECK(!is_upstream_native); + return backend_graph.VisitUpstreamNodes(upstream, VisitEachNative); + }; + auto DoEachOperandOrResult = + [&](const NodeT& upstream) -> adt::Result { + ADT_LET_CONST_REF(is_native, IsNative(upstream)); + if (is_native) { + return DoEach(upstream); + } else { + return backend_graph.VisitUpstreamNodes(upstream, VisitEachPacked); + } + }; + return backend_graph.VisitUpstreamNodes(node, DoEachOperandOrResult); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + ADT_LET_CONST_REF(is_node_native, IsNative(node)); + ADT_CHECK(is_node_native); + auto VisitEachNative = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_downstream_native, IsNative(downstream)); + ADT_CHECK(!is_downstream_native); + return backend_graph.VisitDownstreamNodes(downstream, DoEach); + }; + auto VisitEachPacked = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_downstream_native, IsNative(downstream)); + ADT_CHECK(!is_downstream_native); + return backend_graph.VisitDownstreamNodes(downstream, VisitEachNative); + }; + auto DoEachOperandOrResult = + [&](const NodeT& downstream) -> adt::Result { + ADT_LET_CONST_REF(is_native, IsNative(downstream)); + if (is_native) { + return DoEach(downstream); + } else { + return backend_graph.VisitDownstreamNodes(downstream, VisitEachPacked); + } + }; + return backend_graph.VisitDownstreamNodes(node, DoEachOperandOrResult); + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return backend_graph.GetSmallGraphNodeTopoCstr(node); + } + + adt::Result IgnoredNode(const NodeT& node) const { + ADT_LET_CONST_REF(is_native, IsNative(node)); + if (!is_native) { + return true; + } + return backend_graph.IgnoredNode(node); + } + + adt::Result IsOpNode(const NodeT& node) const { + return backend_graph.IsOpNode(node); + } + + adt::Result TopoSatisfy( + const NodeT& 
node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + return backend_graph.TopoSatisfy(node, node_topo_cstr); + } + + adt::Result IsNative(const NodeT& node) const { + return node.Match([&](const NativeIrOpOperand&) -> bool { return true; }, + [&](const NativeIrOpResult&) -> bool { return true; }, + [&](const auto&) -> bool { return false; }); + } +}; + +struct BlockBoundPirGraphDescriptor { + using NodeT = PirNode; + + private: + std::function(const NodeT&)> BelongToThisBlockOrNotOp_; + DefaultPirGraphDescriptor backend_graph_; + + public: + explicit BlockBoundPirGraphDescriptor( + const std::function(const NodeT&)>& + BelongToThisBlockOrNotOp) + : BelongToThisBlockOrNotOp_(BelongToThisBlockOrNotOp), backend_graph_{} {} + + template + adt::Result VisitUpstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + using Ok = adt::Result; + return backend_graph_.VisitUpstreamNodes( + node, [&](const NodeT& upstream) -> Ok { + ADT_LET_CONST_REF(belong_to_this_block, + BelongToThisBlockOrNotOp_(upstream)); + if (belong_to_this_block) { + return DoEach(upstream); + } + return adt::Ok{}; + }); + } + + template + adt::Result VisitDownstreamNodes(const NodeT& node, + const DoEachT& DoEach) const { + using Ok = adt::Result; + return node.Match( + [&](const NativeIrValue& impl) -> adt::Result { + for (auto iter = impl.value.use_begin(); iter != impl.value.use_end(); + ++iter) { + ADT_LET_CONST_REF( + belong_to_this_block, + BelongToThisBlockOrNotOp_(NativeIrOp{iter->owner()})); + if (belong_to_this_block) { + pir::OpOperand op_operand = *iter; + NativeIrOpOperand ir_op_operand{op_operand}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + continue; + } + auto* user_parent_block = iter->owner()->GetParent(); + ADT_CHECK(user_parent_block != nullptr); + auto* user_parent_op = user_parent_block->GetParentOp(); + if (!user_parent_op->isa()) { + continue; + } + auto fusion_op = + user_parent_op->dyn_cast(); + ADT_LET_CONST_REF(parent_belong_to_this_block, + 
BelongToThisBlockOrNotOp_(PackedIrOp{fusion_op})); + if (!parent_belong_to_this_block) { + continue; + } + const auto& user_op_inputs = + backend_graph_.GetFusionOpInputValues(fusion_op); + for (int i = 0; i < user_op_inputs.size(); ++i) { + if (user_op_inputs.at(i) == impl.value) { + PackedIrOpOperand ir_op_operand{fusion_op, i}; + ADT_RETURN_IF_ERR(DoEach(ir_op_operand)); + } + } + } + return adt::Ok{}; + }, + [&](const auto&) -> Ok { + return backend_graph_.VisitDownstreamNodes( + node, [&](const NodeT& downstream) -> Ok { + ADT_LET_CONST_REF(belong_to_this_block, + BelongToThisBlockOrNotOp_(downstream)); + if (belong_to_this_block) { + return DoEach(downstream); + } + return adt::Ok{}; + }); + }); + } + + adt::Result GetSmallGraphNodeTopoCstr( + const NodeT& node) const { + return backend_graph_.GetSmallGraphNodeTopoCstr(node); + } + + adt::Result IgnoredNode(const NodeT& node) const { + return backend_graph_.IgnoredNode(node); + } + + adt::Result IsOpNode(const NodeT& node) const { + return backend_graph_.IsOpNode(node); + } + + adt::Result TopoSatisfy( + const NodeT& node, + const graph::SmallGraphNodeTopoCstr& node_topo_cstr) const { + return backend_graph_.TopoSatisfy(node, node_topo_cstr); + } +}; + +} // namespace ap::paddle + +namespace ap::graph { + +template <> +struct GraphDescriptor + : public ap::paddle::DefaultPirGraphDescriptor {}; + +template <> +struct GraphDescriptor + : public ap::paddle::RefAugmentedPirGraphDescriptor {}; + +template <> +struct GraphDescriptor + : public ap::paddle::AllOperandAndResultPirGraphDescriptor {}; + +template <> +struct GraphDescriptor + : public ap::paddle::NativeOperandAndResultPirGraphDescriptor {}; + +template <> +struct GraphDescriptor + : public ap::paddle::BlockBoundPirGraphDescriptor { + using ap::paddle::BlockBoundPirGraphDescriptor::BlockBoundPirGraphDescriptor; +}; + +} // namespace ap::graph diff --git a/paddle/ap/include/paddle/pir_node.h b/paddle/ap/include/paddle/pir_node.h new file mode 100644 index 
00000000000000..b3e662aec4ab9d --- /dev/null +++ b/paddle/ap/include/paddle/pir_node.h @@ -0,0 +1,415 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/graph/node_topo_cstr.h" +#include "paddle/ap/include/ir_match/ref_match_ctx.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/pir/include/core/op_operand.h" +#include "paddle/pir/include/core/op_result.h" +#include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/core/value.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace ap::paddle { + +template +axpr::TypeImpl> GetNativeIrValueClass(); + +template +axpr::TypeImpl> GetPackedIrValueClass(); + +template +axpr::TypeImpl> GetRefIrValueClass(); + +template +axpr::TypeImpl> GetNativeIrOpClass(); + +template +axpr::TypeImpl> GetPackedIrOpClass(); + +template +axpr::TypeImpl> GetRefIrOpClass(); + +struct NativeIrValue { + pir::Value value; + + template + static axpr::TypeImpl> GetBuiltinClass() { + return GetNativeIrValueClass(); + } + + std::size_t GetHashValue() 
const { return std::hash()(value); } + + bool operator==(const NativeIrValue& other) const { + return this->value == other.value; + } + + graph::NativeIrValueTopoCstr node_topo_cstr() const { + return graph::NativeIrValueTopoCstr{}; + } + + adt::Result GetDataType() const { + ADT_LET_CONST_REF(type, GetPhiDataType()); + return ap::axpr::GetDataTypeFromPhiDataType(type); + } + + adt::Result*> GetShapeDimExprsPtr() const { + ADT_LET_CONST_REF(shape_or_data, GetShapeOrDataDimExprsPtr()); + return &shape_or_data->shape(); + } + + adt::Result GetShapeOrDataDimExprsPtr() + const { + auto* op = value.defining_op(); + ADT_CHECK(op != nullptr); + auto* program = op->GetParentProgram(); + auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(program); + const auto& shape_or_data = shape_analysis.GetShapeOrDataForValue(value); + using RetT = adt::Result; + return shape_or_data.Match( + [&](const symbol::TensorShapeOrDataDimExprs& impl) -> RetT { + return &shape_or_data; + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + "GetShapeDimExprsPtr only support TensorShapeOrDataDimExprs."}; + }); + } + + private: + adt::Result GetPhiDataType() const { + ADT_LET_CONST_REF(type, GetPirDataType()); + try { + return ::paddle::dialect::TransToPhiDataType(type); + } catch (const std::exception& e) { + return adt::errors::TypeError{ + "failed to cast from pir data type to phi data type."}; + } + } + + adt::Result GetPirDataType() const { + if (!this->value.type().isa()) { + return adt::errors::NotImplementedError{ + "pir value must be of DenseTensorType"}; + } + const auto dense_tensor_type = + this->value.type().dyn_cast(); + return dense_tensor_type.dtype(); + } +}; + +struct PackedIrValue { + cinn::dialect::FusionOp fusion_op; + bool is_output; + + template + static axpr::TypeImpl> GetBuiltinClass() { + return GetPackedIrValueClass(); + } + + std::size_t GetHashValue() const { + return std::hash()( + static_cast(fusion_op)) ^ + is_output; + } + + bool 
operator==(const PackedIrValue& other) const { + return this->fusion_op == other.fusion_op && + this->is_output == other.is_output; + } + + graph::PackedIrValueTopoCstr node_topo_cstr() const { + return graph::PackedIrValueTopoCstr{}; + } +}; + +struct NativeIrOpOperand { + pir::OpOperand op_operand; + + std::size_t GetHashValue() const { + return std::hash()(op_operand); + } + + bool operator==(const NativeIrOpOperand& other) const { + return this->op_operand == other.op_operand; + } + + graph::NativeIrOpOperandTopoCstr node_topo_cstr() const { + return graph::NativeIrOpOperandTopoCstr{this->op_operand.index()}; + } +}; + +struct PackedIrOpOperand { + cinn::dialect::FusionOp fusion_op; + std::size_t free_tensor_index; + + std::size_t GetHashValue() const { + return std::hash()( + static_cast(fusion_op)) ^ + free_tensor_index; + } + + bool operator==(const PackedIrOpOperand& other) const { + return this->fusion_op == other.fusion_op && + this->free_tensor_index == other.free_tensor_index; + } + + graph::PackedIrOpOperandTopoCstr node_topo_cstr() const { + return graph::PackedIrOpOperandTopoCstr{}; + } +}; + +struct NativeIrOp { + pir::Operation* op; + + template + static axpr::TypeImpl> GetBuiltinClass() { + return GetNativeIrOpClass(); + } + + std::size_t GetHashValue() const { return std::hash()(op); } + + bool operator==(const NativeIrOp& other) const { + return this->op == other.op; + } + + graph::NativeIrOpTopoCstr node_topo_cstr() const { + return graph::NativeIrOpTopoCstr{this->op->name()}; + } +}; + +struct PackedIrOp { + cinn::dialect::FusionOp fusion_op; + + template + static axpr::TypeImpl> GetBuiltinClass() { + return GetPackedIrOpClass(); + } + + std::size_t GetHashValue() const { + return std::hash()( + static_cast(fusion_op)); + } + + bool operator==(const PackedIrOp& other) const { + return this->fusion_op == other.fusion_op; + } + + graph::PackedIrOpTopoCstr node_topo_cstr() const { + return graph::PackedIrOpTopoCstr{"ap_trivial_fusion_op"}; + } 
+}; + +struct NativeIrOpResult { + pir::OpResult op_result; + + std::size_t GetHashValue() const { + return std::hash()(op_result); + } + + bool operator==(const NativeIrOpResult& other) const { + return this->op_result == other.op_result; + } + + graph::NativeIrOpResultTopoCstr node_topo_cstr() const { + return graph::NativeIrOpResultTopoCstr{this->op_result.index()}; + } +}; + +struct PackedIrOpResult { + pir::OpResult op_result; + + std::size_t GetHashValue() const { + return std::hash()(op_result); + } + + bool operator==(const PackedIrOpResult& other) const { + return this->op_result == other.op_result; + } + + graph::PackedIrOpResultTopoCstr node_topo_cstr() const { + return graph::PackedIrOpResultTopoCstr{}; + } +}; + +} // namespace ap::paddle + +namespace std { + +template <> +struct hash { + std::size_t operator()(const ap::paddle::NativeIrValue& node) const { + return node.GetHashValue(); + } +}; + +template <> +struct hash { + std::size_t operator()(const ap::paddle::NativeIrOpOperand& node) const { + return node.GetHashValue(); + } +}; + +} // namespace std + +namespace ap::paddle { + +using RefNodeInfo = ir_match::RefNodeInfo; + +struct RefIrValue { + RefNodeInfo ref_node_info; + + template + static axpr::TypeImpl> GetBuiltinClass() { + return GetRefIrValueClass(); + } + + std::size_t GetHashValue() const { + return std::hash()(ref_node_info); + } + + bool operator==(const RefIrValue& other) const { + return this->ref_node_info == other.ref_node_info; + } + + adt::Result GetOwnerNativeIrValue() const { + return this->ref_node_info->ir_value; + } + + graph::RefIrValueTopoCstr node_topo_cstr() const { + return graph::RefIrValueTopoCstr{}; + } +}; + +struct RefIrOpOperand { + RefNodeInfo ref_node_info; + + std::size_t GetHashValue() const { + return std::hash()(ref_node_info); + } + + bool operator==(const RefIrOpOperand& other) const { + return this->ref_node_info == other.ref_node_info; + } + + graph::RefIrOpOperandTopoCstr node_topo_cstr() const { + 
return graph::RefIrOpOperandTopoCstr{}; + } +}; + +struct RefIrOp { + RefNodeInfo ref_node_info; + + template + static axpr::TypeImpl> GetBuiltinClass() { + return GetRefIrOpClass(); + } + + std::size_t GetHashValue() const { + return std::hash()(ref_node_info); + } + + bool operator==(const RefIrOp& other) const { + return this->ref_node_info == other.ref_node_info; + } + + graph::RefIrOpTopoCstr node_topo_cstr() const { + return graph::RefIrOpTopoCstr{}; + } +}; + +struct RefIrOpResult { + RefNodeInfo ref_node_info; + + std::size_t GetHashValue() const { + return std::hash()(ref_node_info); + } + + bool operator==(const RefIrOpResult& other) const { + return this->ref_node_info == other.ref_node_info; + } + + graph::RefIrOpResultTopoCstr node_topo_cstr() const { + return graph::RefIrOpResultTopoCstr{}; + } +}; + +using PirNodeImpl = std::variant; + +struct PirNode : public PirNodeImpl { + using PirNodeImpl::PirNodeImpl; + ADT_DEFINE_VARIANT_METHODS(PirNodeImpl); + + using dim_expr_type = ::symbol::DimExpr; + using native_op_type = NativeIrOp; + using packed_op_type = PackedIrOp; + using ref_op_type = RefIrOp; + using native_value_type = NativeIrValue; + using packed_value_type = PackedIrValue; + using ref_value_type = RefIrValue; + using native_op_operand_type = NativeIrOpOperand; + + std::size_t GetHashValue() const { + return Match([](const auto& impl) { return impl.GetHashValue(); }); + } + + graph::NodeTopoCstr node_topo_cstr() const { + return Match([](const auto& impl) -> graph::NodeTopoCstr { + return impl.node_topo_cstr(); + }); + } + + static adt::Result GetOpNameFromDrrPackedOpName( + const std::string& drr_packed_op_name) { + if (drr_packed_op_name == "ap_trivial_fusion_op") { + return "cinn_op.fusion"; + } + return adt::errors::KeyError{ + std::string() + "no pir op name matched to drr packed op name: '" + + drr_packed_op_name + "'"}; + } +}; + +} // namespace ap::paddle + +namespace std { + +template <> +struct hash { + std::size_t operator()(const 
ap::paddle::PirNode& node) const { + return node.GetHashValue(); + } +}; + +} // namespace std diff --git a/paddle/ap/include/paddle/pir_node_descriptor.h b/paddle/ap/include/paddle/pir_node_descriptor.h new file mode 100644 index 00000000000000..2302755a379a47 --- /dev/null +++ b/paddle/ap/include/paddle/pir_node_descriptor.h @@ -0,0 +1,213 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" +#include "paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/graph/node_descriptor.h" +#include "paddle/ap/include/ir_match/ref_match_ctx.h" +#include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h" +#include "paddle/ap/include/paddle/pir_node.h" + +namespace ap::paddle { + +struct PirNodeDescriptor { + using RefNodeInfo = ir_match::RefNodeInfo; + + std::string DebugId(const PirNode& node) const { + return node.Match( + [&](const NativeIrValue& ir_value) -> std::string { + if (ir_value.value.defining_op() == nullptr) { + return std::to_string( + reinterpret_cast(ir_value.value.impl())); + } else { + const auto* op = ir_value.value.defining_op(); + const auto& op_debug_id = GetOpDebugId(op); + for (int i = 0; i < op->num_results(); ++i) { + if (op->result(i) == ir_value.value) { + return op_debug_id + "_out_" + 
std::to_string(i); + } + } + return op_debug_id + "_error_output"; + } + }, + [&](const PackedIrValue& ir_value) -> std::string { + pir::Operation* op = ir_value.fusion_op; + const auto& op_debug_id = GetOpDebugId(op); + if (ir_value.is_output) { + return op_debug_id + "_packed_out"; + } else { + return op_debug_id + "_packed_in"; + } + }, + [&](const NativeIrOpOperand& ir_op_operand) -> std::string { + const auto& operand = ir_op_operand.op_operand; + const auto& op_debug_id = GetOpDebugId(operand.owner()); + return op_debug_id + "_operand_" + std::to_string(operand.index()); + }, + [&](const PackedIrOpOperand& ir_op_operand) -> std::string { + pir::Operation* op = ir_op_operand.fusion_op; + const auto& op_debug_id = GetOpDebugId(op); + std::size_t index = ir_op_operand.free_tensor_index; + return op_debug_id + "_packed_operand_" + std::to_string(index); + }, + [&](const NativeIrOp& ir_op) -> std::string { + return GetOpDebugId(ir_op.op); + }, + [&](const PackedIrOp& ir_op) -> std::string { + pir::Operation* op = ir_op.fusion_op; + return GetOpDebugId(op); + }, + [&](const NativeIrOpResult& ir_op_result) -> std::string { + pir::Operation* op = ir_op_result.op_result.owner(); + const auto& op_debug_id = GetOpDebugId(op); + std::size_t index = ir_op_result.op_result.index(); + return op_debug_id + "_result_" + std::to_string(index); + }, + [&](const PackedIrOpResult& ir_op_result) -> std::string { + pir::Operation* op = ir_op_result.op_result.owner(); + const auto& op_debug_id = GetOpDebugId(op); + std::size_t index = ir_op_result.op_result.index(); + return op_debug_id + "_packed_result_" + std::to_string(index); + }, + [&](const RefIrValue& impl) -> std::string { + return std::string() + "RefIrValue(" + + GetRefNodeInfoDebugString(impl.ref_node_info) + ")"; + }, + [&](const RefIrOpOperand& impl) -> std::string { + return std::string() + "RefIrOpOperand(" + + GetRefNodeInfoDebugString(impl.ref_node_info) + ")"; + }, + [&](const RefIrOp& impl) -> std::string { + 
return std::string() + "RefIrOp(" + + GetRefNodeInfoDebugString(impl.ref_node_info) + ")"; + }, + [&](const RefIrOpResult& impl) -> std::string { + return std::string() + "RefIrOpResult(" + + GetRefNodeInfoDebugString(impl.ref_node_info) + ")"; + }); + } + + std::string GetRefNodeInfoDebugString( + const RefNodeInfo& ref_node_info) const { + std::ostringstream ss; + ss << DebugId(ref_node_info->ir_value); + ss << "=>["; + int i = 0; + for (const auto& op_operand : *ref_node_info->op_operands_subset) { + if (i++ > 0) { + ss << ","; + } + ss << "(" << DebugId(op_operand) << ")"; + } + ss << "]"; + return ss.str(); + } + + std::string GetOpDebugId(const pir::Operation* op) const { + return op->name() + "_" + std::to_string(op->id()); + } + + adt::Result AttrsSatisfyIfBothAreOpsOrValues( + const PirNode& node, const graph::Node& drr_graph_node) { + ADT_LET_CONST_REF(drr_node, drr_graph_node.Get()); + using RetT = adt::Result; + auto pattern_match = ::common::Overloaded{ + [&](const NativeIrValue& pir_value, + const drr::NativeIrValue& drr_value) -> RetT { + return ValueAttrsSatisfy(pir_value, drr_value); + }, + [&](const NativeIrOp& pir_op, const drr::NativeIrOp& drr_op) + -> RetT { return NativeOpAttrsSatisfy(pir_op, drr_op); }, + [&](const PackedIrOp& pir_op, const drr::PackedIrOp& drr_op) + -> RetT { return PackedOpAttrsSatisfy(pir_op, drr_op); }, + [&](const auto& lhs, const auto& rhs) -> RetT { return true; }}; + return std::visit(pattern_match, node.variant(), drr_node.variant()); + } + + adt::Result ValueAttrsSatisfy( + const NativeIrValue& pir_value, + const drr::NativeIrValue& drr_value) { + ADT_LET_CONST_REF(opt_type, + drr::OpTensorPatternCtxHelper{}.GetOptType(drr_value)); + if (!opt_type.has_value()) { + return true; + } + ADT_LET_CONST_REF(type, opt_type.value().template CastTo()); + return type == pir_value.value.type(); + } + + adt::Result NativeOpAttrsSatisfy( + const NativeIrOp& pir_op, const drr::NativeIrOp& drr_op) { + if 
(drr_op->op_declare->attr_map->storage.empty()) { + return true; + } + for (const auto& [attr_name, attr_val] : + drr_op->op_declare->attr_map->storage) { + const auto& iter = pir_op.op->attributes().find(attr_name); + if (iter == pir_op.op->attributes().end()) { + continue; + } + const auto& pir_attr_val = iter->second; + ADT_LET_CONST_REF(drr_attr_val, + attr_val.template CastTo()); + if (pir_attr_val != drr_attr_val) { + return false; + } + } + return true; + } + + adt::Result PackedOpAttrsSatisfy( + const PackedIrOp& pir_op, const drr::PackedIrOp& drr_op) { + ADT_LET_CONST_REF(inner_source_pattern_satisfy, + PackedOpInnerSourcePatternSatisfy(pir_op, drr_op)); + if (!inner_source_pattern_satisfy) { + return false; + } + return true; + } + + adt::Result PackedOpInnerSourcePatternSatisfy( + const PackedIrOp& pir_op, const drr::PackedIrOp& drr_op) { + ADT_CHECK(drr_op->op_declare->data.has_value()); + auto* raw_data_ptr = drr_op->op_declare->data.value().get(); + auto* data_ptr = + dynamic_cast(raw_data_ptr); + ADT_CHECK(data_ptr != nullptr); + if (!data_ptr->inner_source_pattern_func.has_value()) { + ADT_CHECK(!data_ptr->inner_source_pattern_ctx.has_value()); + return true; + } + ADT_CHECK(data_ptr->inner_source_pattern_ctx.has_value()); + PackedIrOpInnerSourcePatternHelper helper{}; + ADT_LET_CONST_REF( + opt_match_ctx, + helper.Match(pir_op, data_ptr->inner_source_pattern_ctx.value())); + return opt_match_ctx.has_value(); + } +}; + +} // namespace ap::paddle + +namespace ap::graph { + +template <> +struct NodeDescriptor + : public ap::paddle::PirNodeDescriptor {}; + +} // namespace ap::graph diff --git a/paddle/ap/include/paddle/pir_node_helper.h b/paddle/ap/include/paddle/pir_node_helper.h new file mode 100644 index 00000000000000..59b2ea3788bb5f --- /dev/null +++ b/paddle/ap/include/paddle/pir_node_helper.h @@ -0,0 +1,86 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir_node.h" + +namespace ap::paddle { + +struct PirNodeHelper { + using This = PirNodeHelper; + + adt::Result CastFromAxprValue(const axpr::Value& axpr_val) { + using RetT = adt::Result; + return axpr_val.Match( + [&](const axpr::BuiltinClassInstance& instance) -> RetT { + return CastInstanceToPirNode(instance); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + "PirNodeHelper::CastFromAxprValue() failed"}; + }); + } + + private: + using AxprInstanceToPirNodeConverter = + adt::Result (*)(const axpr::BuiltinClassInstance&); + using AxprInstanceToPirNodeMap = + std::map; + + adt::Result CastInstanceToPirNode( + const axpr::BuiltinClassInstance& instance) { + const AxprInstanceToPirNodeMap& map = GetAxprInstanceToPirNodeMap(); + const auto& iter = map.find(instance.instance.type()); + if (iter == map.end()) { + return adt::errors::TypeError{ + "PirNodeHelper::CastInstanceToPirNode failed"}; + } else { + return iter->second(instance); + } + } + + const AxprInstanceToPirNodeMap& GetAxprInstanceToPirNodeMap() { + static const AxprInstanceToPirNodeMap map(MakeAxprInstanceToPirNodeMap()); + return map; + } + + AxprInstanceToPirNodeMap MakeAxprInstanceToPirNodeMap() { + 
AxprInstanceToPirNodeMap map; + InsertEntries(&map); + return map; + } + + template + void InsertEntries(AxprInstanceToPirNodeMap* map) { + if constexpr (start_idx >= std::variant_size_v) { + return; + } else { + using Impl = typename std::variant_alternative_t; + (*map)[typeid(Impl)] = &This::template ConvertAxprInstanceToPirNode; + InsertEntries(map); + } + } + + template + static adt::Result ConvertAxprInstanceToPirNode( + const axpr::BuiltinClassInstance& instance) { + return PirNode{std::any_cast(instance.instance)}; + } +}; + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir_node_method_class.h b/paddle/ap/include/paddle/pir_node_method_class.h new file mode 100644 index 00000000000000..e1efac8db5d6cb --- /dev/null +++ b/paddle/ap/include/paddle/pir_node_method_class.h @@ -0,0 +1,388 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/dim_expr_method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/paddle/pir/attribute_method_class.h" +#include "paddle/ap/include/paddle/pir/shape_or_data_method_class.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" + +namespace ap::paddle { + +template +struct NativeIrValueMethodClass { + using This = NativeIrValueMethodClass; + using Self = NativeIrValue; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const auto* ptr = self.value.impl(); + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return static_cast(std::hash()(self)); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "dtype") { + return This{}.GetDataType(self); + } + if (attr_name == "type") { + return GetPirTypeClass().New(self.value.type()); + } + return adt::errors::TypeError{std::string() + + "NativeIrValue instance has no attribute '" + + attr_name + "'."}; + } + + static adt::Result GetSymbolicShapeOrData( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(shape_or_data_ptr, self.GetShapeOrDataDimExprsPtr()); + return ap::paddle::GetPirShapeOrDataClass().New(*shape_or_data_ptr); + } + + static adt::Result SymbolicShapeToList( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0); + return This{}.GetShape(self); + } + + 
adt::Result GetShape(const Self& self) { + ADT_LET_CONST_REF(shape_ptr, self.GetShapeDimExprsPtr()); + adt::List lst; + lst->reserve(shape_ptr->size()); + for (const auto& dim_expr : *shape_ptr) { + axpr::BuiltinClassInstance instance{ + axpr::GetDimExprClass(), dim_expr}; + lst->emplace_back(instance); + } + return lst; + } + + adt::Result GetDataType(const Self& self) { + ADT_LET_CONST_REF(dtype, self.GetDataType()); + return dtype; + } +}; + +template +axpr::TypeImpl> GetNativeIrValueClass() { + using ImplMethods = NativeIrValueMethodClass; + static auto cls( + axpr::MakeBuiltinClass("NativeIrValue", [&](const auto& Yield) { + Yield("__getattr__", &ImplMethods::GetAttr); + Yield("__str__", &ImplMethods::ToString); + Yield("__hash__", &ImplMethods::Hash); + Yield("symbolic_shape_to_list", &ImplMethods::SymbolicShapeToList); + Yield("get_symbolic_shape_or_data", + &ImplMethods::GetSymbolicShapeOrData); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct PackedIrValueMethodClass { + using This = PackedIrValueMethodClass; + using Self = PackedIrValue; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const pir::Operation* ptr = self.fusion_op; + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const pir::Operation* ptr = self.fusion_op; + return reinterpret_cast(ptr); + } +}; + +template +axpr::TypeImpl> GetPackedIrValueClass() { + using ImplMethods = PackedIrValueMethodClass; + static auto cls( + axpr::MakeBuiltinClass("PackedIrValue", [&](const auto& Yield) { + Yield("__str__", &ImplMethods::ToString); + Yield("__hash__", &ImplMethods::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct RefIrValueMethodClass { + using This = RefIrValueMethodClass; + using Self = RefIrValue; + + 
static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const auto* ptr = self.ref_node_info.__adt_rc_shared_ptr_raw_ptr(); + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return reinterpret_cast( + self.ref_node_info.__adt_rc_shared_ptr_raw_ptr()); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + const auto& attr_name_val = args.at(0); + ADT_LET_CONST_REF(attr_name, attr_name_val.template TryGet()); + if (attr_name == "dtype") { + return This{}.GetDataType(self); + } + return adt::errors::TypeError{std::string() + + "NativeIrValue instance has no attribute '" + + attr_name + "'."}; + } + + static adt::Result SymbolicShapeToList( + const ValueT& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0); + return This{}.GetShape(self); + } + + adt::Result GetShape(const Self& self) { + ADT_LET_CONST_REF(ir_value, self.GetOwnerNativeIrValue()); + ADT_LET_CONST_REF(shape_ptr, ir_value.GetShapeDimExprsPtr()); + adt::List lst; + lst->reserve(shape_ptr->size()); + for (const auto& dim_expr : *shape_ptr) { + axpr::BuiltinClassInstance instance{ + axpr::GetDimExprClass(), dim_expr}; + lst->emplace_back(instance); + } + return lst; + } + + adt::Result GetDataType(const Self& self) { + ADT_LET_CONST_REF(ir_value, self.GetOwnerNativeIrValue()); + ADT_LET_CONST_REF(dtype, ir_value.GetDataType()); + return dtype; + } +}; + +template +axpr::TypeImpl> GetRefIrValueClass() { + using ImplMethods = RefIrValueMethodClass; + static auto cls( + axpr::MakeBuiltinClass("RefIrValue", [&](const auto& Yield) { + Yield("__getattr__", &ImplMethods::GetAttr); + Yield("__str__", 
&ImplMethods::ToString); + Yield("__hash__", &ImplMethods::Hash); + Yield("symbolic_shape_to_list", &ImplMethods::SymbolicShapeToList); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct NativeIrOpMethodClass { + using This = NativeIrOpMethodClass; + using Self = NativeIrOp; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const auto* ptr = self.op; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const pir::Operation* ptr = self.op; + return reinterpret_cast(ptr); + } + + static adt::Result GetAttr(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + const pir::Operation* ptr = self.op; + const auto& attrs = ptr->attributes(); + const auto& iter = attrs.find(attr_name); + if (iter == attrs.end()) { + return adt::errors::KeyError{ + "NativeIrOp.__getattr__() failed. 
can not found attribute '" + + attr_name + "'"}; + } + return GetPirAttributeClass().New(iter->second); + } + + static adt::Result NumOperands(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0) << adt::errors::TypeError{ + std::string() + "NativeIrOp.num_operands() takes 0 arguments, but " + + std::to_string(args.size()) + " were given"}; + const pir::Operation* op = self.op; + int64_t num_operands = op->num_operands(); + return num_operands; + } + + static adt::Result OperandSource(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "NativeIrOp.operand_source() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(i, args.at(0).template CastTo()); + const pir::Operation* op = self.op; + ADT_CHECK(i >= 0); + ADT_CHECK(i < op->num_operands()); + pir::Value value = op->operand_source(i); + return GetNativeIrValueClass().New(NativeIrValue{value}); + } + + static adt::Result NumResults(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0) << adt::errors::TypeError{ + std::string() + "NativeIrOp.num_results() takes 0 arguments, but " + + std::to_string(args.size()) + " were given"}; + const pir::Operation* op = self.op; + int64_t num_results = op->num_results(); + return num_results; + } + + static adt::Result Result(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "NativeIrOp.result() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(i, args.at(0).template CastTo()); + const pir::Operation* op = self.op; + ADT_CHECK(i >= 0); + ADT_CHECK(i < 
op->num_results()); + pir::Value value = op->result(i); + return GetNativeIrValueClass().New(NativeIrValue{value}); + } +}; + +template +axpr::TypeImpl> GetNativeIrOpClass() { + using Impl = NativeIrOpMethodClass; + static auto cls( + axpr::MakeBuiltinClass("NativeIrOp", [&](const auto& Yield) { + Yield("__str__", &Impl::ToString); + Yield("__hash__", &Impl::Hash); + Yield("__getattr__", &Impl::GetAttr); + Yield("num_operands", &Impl::NumOperands); + Yield("operand_source", &Impl::OperandSource); + Yield("num_results", &Impl::NumResults); + Yield("result", &Impl::Result); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct PackedIrOpMethodClass { + using This = PackedIrOpMethodClass; + using Self = PackedIrOp; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const pir::Operation* ptr = self.fusion_op; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const pir::Operation* ptr = self.fusion_op; + return reinterpret_cast(ptr); + } +}; + +template +axpr::TypeImpl> GetPackedIrOpClass() { + using Impl = PackedIrOpMethodClass; + static auto cls( + axpr::MakeBuiltinClass("PackedIrOp", [&](const auto& Yield) { + Yield("__str__", &Impl::ToString); + Yield("__hash__", &Impl::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +template +struct RefIrOpMethodClass { + using This = RefIrOpMethodClass; + using Self = RefIrOp; + + static adt::Result ToString(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const auto* ptr = self.ref_node_info.__adt_rc_shared_ptr_raw_ptr(); + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, 
self_val.template CastTo()); + return reinterpret_cast( + self.ref_node_info.__adt_rc_shared_ptr_raw_ptr()); + } +}; + +template +axpr::TypeImpl> GetRefIrOpClass() { + using Impl = RefIrOpMethodClass; + static auto cls( + axpr::MakeBuiltinClass("RefIrOp", [&](const auto& Yield) { + Yield("__str__", &Impl::ToString); + Yield("__hash__", &Impl::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir_util.h b/paddle/ap/include/paddle/pir_util.h new file mode 100644 index 00000000000000..f0c6b9cf1c259c --- /dev/null +++ b/paddle/ap/include/paddle/pir_util.h @@ -0,0 +1,63 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +namespace ap::paddle { + +struct FreeVarHelper { + void GetUsedExternalValueImpl( + std::unordered_set& defined_values, // NOLINT + std::vector& used_values, // NOLINT + const pir::Operation& op) { + for (size_t index = 0; index < op.num_operands(); ++index) { + pir::Value value = op.operand_source(index); + if (defined_values.find(value) == defined_values.end()) { + used_values.push_back(value); + defined_values.insert(value); + } + } + for (auto& region : op) { + for (auto& block : region) { + for (auto value : block.args()) { + defined_values.insert(value); + } + for (const auto& [_, value] : block.kwargs()) { + defined_values.insert(value); + } + } + for (auto& block : region) { + for (auto& inner_op : block) { + GetUsedExternalValueImpl(defined_values, used_values, inner_op); + } + } + } + for (size_t index = 0; index < op.num_results(); ++index) { + defined_values.insert(op.result(index)); + } + } + + std::vector GetUsedExternalValue(const pir::Operation& op) { + std::unordered_set defined_values{nullptr}; + std::vector used_values; + GetUsedExternalValueImpl(defined_values, used_values, op); + return used_values; + } +}; + +inline std::vector GetUsedExternalValue(const pir::Operation& op) { + return FreeVarHelper{}.GetUsedExternalValue(op); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h b/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h new file mode 100644 index 00000000000000..df384ef82dd055 --- /dev/null +++ b/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h @@ -0,0 +1,79 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/meta_tensor_ptr_method_class.h" + +namespace ap::paddle { + +struct StdVectorMetaTensorPtrPtrMethodClass { + using This = StdVectorMetaTensorPtrPtrMethodClass; + using Self = std::vector*; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self; + ss << ""; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return reinterpret_cast(self); + } + + static adt::Result GetItem( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& idx_val = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(idx, idx_val.template CastTo()) + << adt::errors::TypeError{std::string() + + "vector indices must be integers, not " + + axpr::GetTypeName(idx_val)}; + int64_t index = idx; + if (index < 0) { + index += self->size(); + } + if (index >= 0 && index < self->size()) { + return CastItem(self->at(index)); + } + return adt::errors::IndexError{"vector index out of range"}; + } + + static adt::Result CastItem(const MetaTensorPtr& elem) { + return GetMetaTensorPtrClass().New(elem); + } +}; + +inline axpr::TypeImpl> +GetStdVectorMetaTensorPtrPtrClass() { + 
using Impl = StdVectorMetaTensorPtrPtrMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "StdVectorMetaTensorPtrPtr", [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getitem__", &Impl::GetItem); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/include/preprocessor/preprocessor.h b/paddle/ap/include/preprocessor/preprocessor.h new file mode 100644 index 00000000000000..42b51d99129f1f --- /dev/null +++ b/paddle/ap/include/preprocessor/preprocessor.h @@ -0,0 +1,18 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#define AP_CONCAT(a, b) AP_CONCAT_I(a, b) +#define AP_CONCAT_I(a, b) a##b diff --git a/paddle/ap/include/registry/abstract_drr_pass_registry_item.h b/paddle/ap/include/registry/abstract_drr_pass_registry_item.h new file mode 100644 index 00000000000000..73fc9615ff8602 --- /dev/null +++ b/paddle/ap/include/registry/abstract_drr_pass_registry_item.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::registry { + +struct AbstractDrrPassRegistryItemImpl { + std::string abstract_drr_pass_name; + int64_t nice; + axpr::ClassAttrs cls; +}; + +ADT_DEFINE_RC(AbstractDrrPassRegistryItem, AbstractDrrPassRegistryItemImpl); + +} // namespace ap::registry diff --git a/paddle/ap/include/registry/access_topo_drr_pass_registry_item.h b/paddle/ap/include/registry/access_topo_drr_pass_registry_item.h new file mode 100644 index 00000000000000..3b4018bead50b9 --- /dev/null +++ b/paddle/ap/include/registry/access_topo_drr_pass_registry_item.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::registry { + +struct AccessTopoDrrPassRegistryItemImpl { + std::string access_topo_drr_pass_name; + std::string pass_tag_name; + int64_t nice; + axpr::ClassAttrs cls; +}; + +ADT_DEFINE_RC(AccessTopoDrrPassRegistryItem, AccessTopoDrrPassRegistryItemImpl); + +} // namespace ap::registry diff --git a/paddle/ap/include/registry/builtin_frame_util.h b/paddle/ap/include/registry/builtin_frame_util.h new file mode 100644 index 00000000000000..2255db278710ec --- /dev/null +++ b/paddle/ap/include/registry/builtin_frame_util.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/registry/registry_class.h" + +namespace ap::registry { + +template +void VisitEachBuiltinFrameAttr(const DoEachT& DoEach) { + const auto& registry_class = MakeRegistryClass(); + DoEach(registry_class.Name(), ValueT{registry_class}); +} + +template +axpr::AttrMap MakeBuiltinFrameAttrMap() { + axpr::AttrMap attr_map; + auto Insert = [&](const std::string& k, const ValueT& v) { + attr_map->Set(k, v); + }; + axpr::VisitEachBuiltinFrameAttr(Insert); + VisitEachBuiltinFrameAttr(Insert); + return attr_map; +} + +} // namespace ap::registry diff --git a/paddle/ap/include/registry/classic_drr_pass_registry_item.h b/paddle/ap/include/registry/classic_drr_pass_registry_item.h new file mode 100644 index 00000000000000..e436bb91bc88ea --- /dev/null +++ b/paddle/ap/include/registry/classic_drr_pass_registry_item.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/function.h" +#include "paddle/ap/include/axpr/serializable_value.h" + +namespace ap::registry { + +struct ClassicDrrPassRegistryItemImpl { + std::string classic_drr_pass_name; + int64_t nice; + axpr::ClassAttrs cls; +}; + +ADT_DEFINE_RC(ClassicDrrPassRegistryItem, ClassicDrrPassRegistryItemImpl); + +} // namespace ap::registry diff --git a/paddle/ap/include/registry/registry.h b/paddle/ap/include/registry/registry.h new file mode 100644 index 00000000000000..b4133d41972f4c --- /dev/null +++ b/paddle/ap/include/registry/registry.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/registry/abstract_drr_pass_registry_item.h" +#include "paddle/ap/include/registry/access_topo_drr_pass_registry_item.h" +#include "paddle/ap/include/registry/classic_drr_pass_registry_item.h" + +namespace ap::registry { + +template +using Key2Nice2Items = std::map>>; + +struct RegistryImpl { + Key2Nice2Items abstract_drr_pass_registry_items; + Key2Nice2Items classic_drr_pass_registry_items; + Key2Nice2Items + access_topo_drr_pass_registry_items; + + bool operator==(const RegistryImpl& other) const { return this == &other; } +}; + +ADT_DEFINE_RC(Registry, RegistryImpl); + +} // namespace ap::registry + +namespace ap::axpr { + +template <> +struct TypeImpl : public std::monostate { + using std::monostate::monostate; + const char* Name() const { return "Registry"; } +}; + +} // namespace ap::axpr diff --git a/paddle/ap/include/registry/registry_class.h b/paddle/ap/include/registry/registry_class.h new file mode 100644 index 00000000000000..c3887feaef66fd --- /dev/null +++ b/paddle/ap/include/registry/registry_class.h @@ -0,0 +1,143 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/builtin_class_instance.h" +#include "paddle/ap/include/axpr/method.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/type.h" +#include "paddle/ap/include/registry/registry.h" +#include "paddle/ap/include/registry/registry_singleton.h" + +namespace ap::registry { + +template +adt::Result RegisterAbstractDrrPass(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string() + "'Registry.abstract_drr_pass()' takes 3 arguments. but " + + std::to_string(args.size()) + " were given."}; + const auto& drr_name_val = args.at(0); + ADT_LET_CONST_REF(drr_name, axpr::TryGetImpl(drr_name_val)) + << adt::errors::TypeError{std::string() + + "argument 1 of 'Registry.abstract_drr_pass()' " + "should be string, but '" + + axpr::GetTypeName(drr_name_val) + + "' were given."}; + const auto& nice_val = args.at(1); + ADT_LET_CONST_REF(nice, axpr::TryGetImpl(nice_val)) + << adt::errors::TypeError{std::string() + + "argument 2 of 'Registry.abstract_drr_pass()' " + "should be int, but '" + + axpr::GetTypeName(nice_val) + "' were given."}; + const auto& cls_val = args.at(2); + ADT_LET_CONST_REF( + type_impl, + axpr::TryGetTypeImpl>>( + cls_val)) + << adt::errors::TypeError{ + std::string() + + "argument 3 of 'Registry.abstract_drr_pass()' should " + "be non-builtin class, but '" + + axpr::GetTypeName(cls_val) + "' were given."}; + AbstractDrrPassRegistryItem item{drr_name, nice, type_impl.class_attrs}; + RegistrySingleton::Add(item); + return adt::Nothing{}; +} + +template +adt::Result RegisterClassicDrrPass(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string() + "'Registry.classic_drr_pass()' takes 3 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + const auto& drr_name_val = args.at(0); + ADT_LET_CONST_REF(drr_name, axpr::TryGetImpl(drr_name_val)) + << adt::errors::TypeError{std::string() + + "argument 1 of 'Registry.classic_drr_pass()' " + "should be string, but '" + + axpr::GetTypeName(drr_name_val) + + "' were given."}; + const auto& nice_val = args.at(1); + ADT_LET_CONST_REF(nice, axpr::TryGetImpl(nice_val)) + << adt::errors::TypeError{std::string() + + "argument 2 of 'Registry.classic_drr_pass()' " + "should be int, but '" + + axpr::GetTypeName(nice_val) + "' were given."}; + const auto& cls_val = args.at(2); + ADT_LET_CONST_REF( + type_impl, + axpr::TryGetTypeImpl>>( + cls_val)) + << adt::errors::TypeError{ + std::string() + + "argument 3 of 'Registry.classic_drr_pass()' should " + "be non-builtin class, but '" + + axpr::GetTypeName(cls_val) + "' were given."}; + ClassicDrrPassRegistryItem item{drr_name, nice, type_impl.class_attrs}; + RegistrySingleton::Add(item); + return adt::Nothing{}; +} + +template +adt::Result RegisterAccessTopoDrrPass(const ValueT&, + const std::vector& args) { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string() + + "'Registry.access_topo_drr_pass()' takes 3 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + const auto& drr_name_val = args.at(0); + ADT_LET_CONST_REF(drr_name, axpr::TryGetImpl(drr_name_val)) + << adt::errors::TypeError{ + std::string() + + "argument 1 of 'Registry.access_topo_drr_pass()' " + "should be string, but '" + + axpr::GetTypeName(drr_name_val) + "' were given."}; + ADT_LET_CONST_REF(pass_tag_name, axpr::TryGetImpl(args.at(1))) + << adt::errors::TypeError{ + std::string() + + "argument 2 of 'Registry.access_topo_drr_pass()' " + "should be int, but '" + + axpr::GetTypeName(args.at(1)) + "' were given."}; + const auto& cls_val = args.at(2); + ADT_LET_CONST_REF( + type_impl, + axpr::TryGetTypeImpl>>( + cls_val)) + << adt::errors::TypeError{ + std::string() + + "argument 3 of 'Registry.access_topo_drr_pass()' should " + "be non-builtin class, but '" + + axpr::GetTypeName(cls_val) + "' were given."}; + AccessTopoDrrPassRegistryItem item{ + drr_name, pass_tag_name, 0, type_impl.class_attrs}; + RegistrySingleton::Add(item); + return adt::Nothing{}; +} + +template +axpr::TypeImpl> MakeRegistryClass() { + static auto cls( + axpr::MakeBuiltinClass("Registry", [&](const auto& DoEach) { + DoEach("abstract_drr_pass", &RegisterAbstractDrrPass); + DoEach("classic_drr_pass", &RegisterClassicDrrPass); + DoEach("access_topo_drr_pass", &RegisterAccessTopoDrrPass); + })); + using Self = Registry; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::registry diff --git a/paddle/ap/include/registry/registry_mgr.h b/paddle/ap/include/registry/registry_mgr.h new file mode 100644 index 00000000000000..f5b11574f45ed0 --- /dev/null +++ b/paddle/ap/include/registry/registry_mgr.h @@ -0,0 +1,103 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
#pragma once

// NOTE(review): the four system-header names below were stripped by the
// extraction tooling (the code uses std::mutex, std::optional, std::ifstream,
// std::string — TODO confirm the exact headers); tokens reproduced as found.
#include
#include
#include
#include
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/anf_expr_util.h"
#include "paddle/ap/include/axpr/function.h"
#include "paddle/ap/include/axpr/interpreter.h"
#include "paddle/ap/include/axpr/module_mgr.h"
#include "paddle/ap/include/axpr/serializable_value.h"
#include "paddle/ap/include/env/ap_path.h"
#include "paddle/ap/include/fs/fs.h"
#include "paddle/ap/include/registry/builtin_frame_util.h"
#include "paddle/ap/include/registry/value.h"

namespace ap::registry {

// Loads and interprets every AP entry script ("__main__.py.json" under each
// configured AP path) exactly once per process, caching the first failure
// (or Ok) for all subsequent callers.
struct RegistryMgr {
  static RegistryMgr* Singleton() {
    static RegistryMgr mgr{};
    return &mgr;
  }

  // Idempotent entry point: the first call walks all entry files and
  // interprets them; later calls return the cached result. Thread-safe via
  // mutex_ (held for the whole load pass).
  adt::Result LoadAllOnce() {
    std::unique_lock lock(mutex_);
    if (!first_load_result_.has_value()) {
      using Ok = adt::Result;
      ADT_RETURN_IF_ERR(VisitApEntryFilePath([&](const auto& filepath) -> Ok {
        const Ok& cur_result = Load(filepath);
        // Only the FIRST error is remembered; the remaining files are still
        // loaded best-effort (the lambda always returns Ok).
        if (!first_load_result_.has_value() && cur_result.HasError()) {
          first_load_result_ = cur_result;
        }
        return adt::Ok{};
      }));
      if (!first_load_result_.has_value()) {
        first_load_result_ = adt::Ok{};
      }
    }
    return first_load_result_.value();
  }

 private:
  // Cached outcome of the first load pass; empty until LoadAllOnce runs.
  std::optional> first_load_result_;
  // Guards first_load_result_.
  std::mutex mutex_;

  // Parses the file as a JSON-serialized AnfExpr and interprets it as a
  // module in a fresh frame. Empty files are silently treated as success.
  adt::Result Load(const std::string& filepath) {
    ADT_LET_CONST_REF(file_content, GetFileContent(filepath));
    if (file_content.empty()) {
      return adt::Ok{};
    }
    ADT_LET_CONST_REF(anf_expr, axpr::MakeAnfExprFromJsonString(file_content));
    const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr);
    const auto& frame = axpr::Frame::Make(
        axpr::ModuleMgr::Singleton()->circlable_ref_list(),
        std::make_shared>());
    std::vector> args{};
    // Wrap the module body in a zero-argument lambda for the interpreter.
    axpr::Lambda lambda{args, core_expr};
    memory::Guard guard{};
    axpr::Interpreter cps_expr_interpreter(
        registry::MakeBuiltinFrameAttrMap(),
        guard.circlable_ref_list());
    ADT_RETURN_IF_ERR(cps_expr_interpreter.InterpretModule(frame, lambda));
    return adt::Ok{};
  }

  // Reads the whole file into a string. NOTE(review): stream failure is not
  // checked, so an unreadable file is indistinguishable from an empty one;
  // callers pre-check fs::FileExists, but permission errors are swallowed.
  adt::Result GetFileContent(const std::string& filepath) {
    std::ifstream ifs(filepath);
    std::string content{std::istreambuf_iterator(ifs),
                        std::istreambuf_iterator()};
    return content;
  }

  // Yields "<dir>/__main__.py.json" for each configured AP path where that
  // file exists; Yield errors abort the walk.
  template
  adt::Result VisitApEntryFilePath(const YieldT& Yield) {
    using Ctrl = adt::Result;
    ADT_RETURN_IF_ERR(env::VisitEachApPath([&](const auto& dir_path) -> Ctrl {
      const std::string file_path = std::string(dir_path) + "/__main__.py.json";
      if (fs::FileExists(file_path)) {
        ADT_RETURN_IF_ERR(Yield(file_path));
      }
      return adt::Continue{};
    }));
    return adt::Ok{};
  }
};

}  // namespace ap::registry
#pragma once

// NOTE(review): the system-header name below was stripped by the extraction
// tooling (std::mutex / std::optional are used — TODO confirm exact header).
#include
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/registry/registry.h"

namespace ap::registry {

// Process-wide store of registered DRR pass items; all state lives in
// function-local statics guarded by a single mutex.
struct RegistrySingleton {
  // Returns the lazily-created Registry handle, or NotImplementedError if
  // nothing has been Add()ed yet (Add is what creates the singleton).
  static adt::Result Singleton() {
    std::unique_lock lock(*SingletonMutex());
    ADT_CHECK(MutOptSingleton()->has_value())
        << adt::errors::NotImplementedError{
               std::string() + "Registry singleton not initialized. "};
    return MutOptSingleton()->value();
  }

  // Registers an abstract DRR pass item under its name and nice level.
  // Note: MutSingleton() takes SingletonMutex() and releases it before the
  // second lock below — two-step locking, not a recursive deadlock; the
  // registry handle stays valid between the two critical sections.
  static void Add(const AbstractDrrPassRegistryItem& item) {
    auto registry = MutSingleton();
    const auto& abstract_drr_pass_name = item->abstract_drr_pass_name;
    int64_t nice = item->nice;
    std::unique_lock lock(*SingletonMutex());
    registry->abstract_drr_pass_registry_items[abstract_drr_pass_name][nice]
        .emplace_back(item);
  }

  // Registers a classic DRR pass item under its name and nice level.
  static void Add(const ClassicDrrPassRegistryItem& item) {
    auto registry = MutSingleton();
    const auto& classic_drr_pass_name = item->classic_drr_pass_name;
    int64_t nice = item->nice;
    std::unique_lock lock(*SingletonMutex());
    registry->classic_drr_pass_registry_items[classic_drr_pass_name][nice]
        .emplace_back(item);
  }

  // Registers an access-topo DRR pass item under its name and nice level.
  static void Add(const AccessTopoDrrPassRegistryItem& item) {
    auto registry = MutSingleton();
    const auto& access_topo_drr_pass_name = item->access_topo_drr_pass_name;
    int64_t nice = item->nice;
    std::unique_lock lock(*SingletonMutex());
    registry
        ->access_topo_drr_pass_registry_items[access_topo_drr_pass_name][nice]
        .emplace_back(item);
  }

  // Function-local static avoids static-initialization-order issues.
  static std::mutex* SingletonMutex() {
    static std::mutex mutex;
    return &mutex;
  }

 private:
  // Creates the Registry on first use and returns the shared handle.
  static Registry MutSingleton() {
    std::unique_lock lock(*SingletonMutex());
    if (!MutOptSingleton()->has_value()) {
      *MutOptSingleton() = Registry{};
    }
    return MutOptSingleton()->value();
  }

  static std::optional* MutOptSingleton() {
    static std::optional ctx{};
    return &ctx;
  }
};

}  // namespace ap::registry
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/core_expr.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/registry/registry.h"

namespace ap::registry {

// The registry subsystem evaluates scripts with the axpr interpreter's
// value type directly.
using axpr::Value;

// Short alias used throughout the registry code.
using Val = Value;

// Interpreter environment for registry evaluation. NOTE(review): the
// template argument list (presumably <Val>) was stripped by the extraction
// tooling — TODO confirm against the original header.
using Env = ap::axpr::Environment;

}  // namespace ap::registry
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
#include "paddle/ap/include/axpr/value.h"

namespace ap::reified_drr {

// Interface for converting DRR node types/attributes into AnfExpr code
// emitted through a LetContext. NOTE(review): the adt::Result<...> element
// types were stripped by the extraction tooling — presumably an AnfExpr or
// LetVar handle; confirm against the original header.
struct DrrNodeAttrToAnfExprHelper {
  virtual ~DrrNodeAttrToAnfExprHelper() {}

  // Emits into `ctx` an expression reconstructing the given type value.
  virtual adt::Result ConvertTypeToAnfExpr(axpr::LetContext* ctx,
                                           axpr::Value type) = 0;
  // Emits into `ctx` an expression reconstructing the given attribute value.
  virtual adt::Result ConvertAttrToAnfExpr(axpr::LetContext* ctx,
                                           axpr::Value attr) = 0;
};

}  // namespace ap::reified_drr
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/drr/drr_ctx.h"
#include "paddle/ap/include/drr/node.h"
#include "paddle/ap/include/drr/source_pattern_ctx.h"

namespace ap::reified_drr {

// Read-only access to a matched source pattern: the pattern context itself,
// per-op attributes, and the types of matched values.
struct MatchedSrcPtnCtxHelper {
  virtual ~MatchedSrcPtnCtxHelper() {}

  // The source pattern context this match was produced from.
  virtual drr::SourcePatternCtx src_ptn_ctx() = 0;

  // Builds a helper scoped to the inner pattern of a packed ir op.
  virtual adt::Result>
  MakeInnerMatchedSrcPtnCtxHelper(
      const drr::PackedIrOp& packed_ir_op) = 0;

  // Invokes DoEachAttr for every (name, value) attribute of the native op;
  // a DoEachAttr error aborts the visit.
  virtual adt::Result VisitNativeIrOpAttr(
      const drr::NativeIrOp& native_ir_op,
      const std::function(const std::string& attr_name,
                          const axpr::Value& attr_val)>&
          DoEachAttr) = 0;

  // Returns the (axpr-encoded) type of the given native ir value.
  virtual adt::Result GetNativeIrValueType(
      const drr::NativeIrValue& native_ir_value) = 0;
};

}  // namespace ap::reified_drr
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/code_gen/code_gen_result.h"
#include "paddle/ap/include/code_module/code_module.h"
#include "paddle/ap/include/drr/drr_ctx.h"
#include "paddle/ap/include/reified_drr/drr_node_attr_to_anf_expr_helper.h"
#include "paddle/ap/include/reified_drr/matched_src_ptn_ctx_helper.h"

namespace ap::reified_drr {

// Dumps a reified DRR pass — an abstract pass specialized by a concrete
// match — for offline inspection. Definitions live elsewhere.
struct ReifiedDrrPassDumpHelper {
  // Whether dumping is enabled (mechanism not visible here — presumably an
  // env flag; TODO confirm). NOTE(review): non-const while Dump() is const.
  bool DumpEnabled();

  // Returns reified drr_pass_class lambda
  // `CodeGenResult4FusedOpName` resolves the code-gen result for a fused op
  // by name; `nice` is the priority recorded with the dumped pass.
  adt::Result Dump(
      const drr::DrrCtx& abstract_drr_ctx,
      DrrNodeAttrToAnfExprHelper* attr2axpr_helper,
      MatchedSrcPtnCtxHelper* src_ptn_ctx_helper,
      const std::function>(
          const std::string&)>& CodeGenResult4FusedOpName,
      int64_t nice) const;
};

}  // namespace ap::reified_drr
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/code_gen/code_gen_result.h"
#include "paddle/ap/include/drr/result_pattern_ctx.h"

namespace ap::reified_drr {

// Emits AnfExpr code that rebuilds a result-pattern context: first the ops,
// then the op<->value connections.
class ReifiedResPtnAxprMaker {
  // The result pattern being reified (held by value).
  drr::ResultPatternCtx res_ptn_ctx_;
  // Resolves the code-gen result for a fused op by name.
  std::function>(
      const std::string&)>
      CodeGenResult4FusedOpName_;

 public:
  ReifiedResPtnAxprMaker(
      const drr::ResultPatternCtx& res_ptn_ctx,
      const std::function>(
          const std::string&)>& CodeGenResult4FusedOpName)
      : res_ptn_ctx_(res_ptn_ctx),
        CodeGenResult4FusedOpName_(CodeGenResult4FusedOpName) {}

  // Emits declarations for every op in the result pattern ctx.
  adt::Result GenAnfExprForResPtnCtxOps(axpr::LetVar* op_pattern_ctx);

  // Emits the input/output wiring between ops and tensors.
  adt::Result GenAnfExprForResPtnCtxOpValueConnections(
      axpr::LetVar* op_pattern_ctx, axpr::LetVar* tensor_pattern_ctx);
};

}  // namespace ap::reified_drr
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
#include "paddle/ap/include/drr/source_pattern_ctx.h"
#include "paddle/ap/include/reified_drr/drr_node_attr_to_anf_expr_helper.h"
#include "paddle/ap/include/reified_drr/matched_src_ptn_ctx_helper.h"

namespace ap::reified_drr {

// Emits AnfExpr code that rebuilds a matched source pattern: its ops, its
// tensor values, and the connections between them. Both helpers are
// borrowed raw pointers; the caller must keep them alive while this maker
// is in use.
class ReifiedSrcPtnAxprMaker {
 public:
  ReifiedSrcPtnAxprMaker(DrrNodeAttrToAnfExprHelper* anf_expr_helper,
                         MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper)
      : anf_expr_helper_(anf_expr_helper),
        matched_src_ptn_ctx_helper_(matched_src_ptn_ctx_helper) {}

  // Emits declarations for every op in the source pattern ctx.
  adt::Result GenAnfExprForSrcPtnCtxOps(axpr::LetVar* op_pattern_ctx);

  // Emits declarations for every tensor value in the source pattern ctx.
  adt::Result GenAnfExprForSrcPtnCtxValues(
      axpr::LetVar* tensor_pattern_ctx);

  // Emits the input/output wiring between ops and tensors.
  adt::Result GenAnfExprForSrcPtnCtxOpValueConnections(
      axpr::LetVar* op_pattern_ctx, axpr::LetVar* tensor_pattern_ctx);

 private:
  DrrNodeAttrToAnfExprHelper* anf_expr_helper_;
  MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper_;
};

}  // namespace ap::reified_drr
#pragma once
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/code_module/arg_type.h"
#include "paddle/ap/include/code_module/data_type.h"

namespace ap::rt_module {

using code_module::ArgType;

// NOTE(review): the variant's alternative list was stripped by the
// extraction tooling; from the usage below it holds axpr::DataValue and
// axpr::PointerValue — confirm against the original header.
using ArgValueImpl = std::variant;

// A runtime function-call argument: either a plain data value or a pointer.
struct ArgValue : public ArgValueImpl {
  using ArgValueImpl::ArgValueImpl;
  ADT_DEFINE_VARIANT_METHODS(ArgValueImpl);

  // The ArgType tag of the currently-held alternative.
  ArgType GetType() const {
    return Match([](auto impl) -> ArgType { return impl.GetType(); });
  }

  // Converts the held alternative into the requested result type.
  template
  adt::Result CastTo() const {
    return Match([](const auto& impl) -> adt::Result { return impl; });
  }

  // Extracts the raw C++ value: pointer-typed Ts are routed through the
  // PointerValue alternative, everything else through DataValue.
  template
  adt::Result TryGetValue() const {
    if constexpr (std::is_pointer_v) {
      const auto& pointer_value =
          this->template TryGet();
      ADT_RETURN_IF_ERR(pointer_value);
      return pointer_value.GetOkValue().template TryGet();
    } else {
      const auto& data_value = this->template TryGet();
      ADT_RETURN_IF_ERR(data_value);
      return data_value.GetOkValue().template TryGet();
    }
  }
};

// Converts an interpreter value into an ArgValue; only DataValue and
// PointerValue are accepted — anything else yields a TypeError.
template
Result CastToArgValue(const ValueT& value) {
  return value.Match(
      [&](const ap::axpr::DataValue& impl) -> Result { return impl; },
      [&](const ap::axpr::PointerValue& impl) -> Result {
        return impl;
      },
      [&](const auto&) -> Result {
        return TypeError{std::string() +
                         "CastToArgValue failed. expected types: "
                         "(DataValue, PointerValue), actual type: " +
                         axpr::GetTypeName(value)};
      });
}

}  // namespace ap::rt_module
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" + +namespace ap::rt_module { + +class DlHandle; + +// dynamic link function +class DlFunction { + public: + DlFunction(const std::shared_ptr& dl_handle, + void* func, + void (*wrapper)(void* ret, void* func, void** args)) + : dl_handle_(dl_handle), func_(func), api_wrapper_(wrapper) {} + + DlFunction(const DlFunction&) = default; + DlFunction(DlFunction&&) = default; + + bool operator==(const DlFunction& other) const { + // It's correct to ignore dl_handle_ + return this->func_ == other.func_ && + this->api_wrapper_ == other.api_wrapper_; + } + + adt::Result Apply(void* ret, void** args) const { + ADT_LET_CONST_REF(dl_handle_guard, adt::WeakPtrLock(dl_handle_)); + api_wrapper_(ret, func_, args); + return adt::Ok{}; + } + + private: + std::weak_ptr dl_handle_; + void* func_; + void (*api_wrapper_)(void* ret, void* func, void** args); +}; + +} // namespace ap::rt_module diff --git a/paddle/ap/include/rt_module/dl_handle.h b/paddle/ap/include/rt_module/dl_handle.h new file mode 100644 index 00000000000000..fd793704f1aa72 --- /dev/null +++ b/paddle/ap/include/rt_module/dl_handle.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
#pragma once

// NOTE(review): the system-header name was stripped by the extraction
// tooling (std::string and std::enable_shared_from_this are used below).
#include
#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/rt_module/dl_function.h"

namespace ap::rt_module {

// dynamic link handle
// Abstract owner of a dynamically loaded library; enable_shared_from_this
// lets resolved DlFunctions keep a weak back-reference for lifetime checks
// (see DlFunction::Apply).
class DlHandle : public std::enable_shared_from_this {
 public:
  virtual ~DlHandle() = default;

  // Resolves `name` to a callable DlFunction, or an error if not found.
  virtual adt::Result DlSym(const std::string& name) const = 0;

 protected:
  // Constructible only through derived classes (e.g. NaiveDlHandle).
  DlHandle() = default;
};

}  // namespace ap::rt_module
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_module/func_declare.h" +#include "paddle/ap/include/rt_module/dl_function.h" + +namespace ap::rt_module { + +struct FunctionImpl { + code_module::FuncDeclare func_declare; + DlFunction dl_function; + + bool operator==(const FunctionImpl& other) const { + return this->func_declare == other.func_declare && + this->dl_function == other.dl_function; + } +}; + +ADT_DEFINE_RC(Function, FunctionImpl); + +} // namespace ap::rt_module diff --git a/paddle/ap/include/rt_module/function_helper.h b/paddle/ap/include/rt_module/function_helper.h new file mode 100644 index 00000000000000..4255d70909830f --- /dev/null +++ b/paddle/ap/include/rt_module/function_helper.h @@ -0,0 +1,122 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/rt_module/arg_value.h"
#include "paddle/ap/include/rt_module/function.h"

namespace ap::rt_module {

// Bridges interpreter values to a dynamically loaded C function: checks the
// call signature, marshals every argument to a void*, invokes the symbol,
// and returns the (default-initialized, callee-written) result slot.
// NOTE(review): template argument lists in this header were stripped by the
// extraction tooling; non-comment tokens reproduced as found.
struct FunctionHelper {
  // Calls `function` with `args`. Verifies arity and each argument's ArgType
  // before the call; `ret` points into `ret_val`, which the callee fills in
  // through the ABI wrapper.
  adt::Result Apply(const Function& function,
                    const std::vector& args) {
    const auto& func_declare = function->func_declare;
    ADT_LET_CONST_REF(ret_val, GetDefaultVal(func_declare->ret_type));
    ADT_LET_CONST_REF(ret_ptr, GetAddrAsVoidPtr(ret_val));
    void* ret = ret_ptr;
    std::vector void_ptr_args;
    void_ptr_args.reserve(args.size());
    ADT_CHECK(func_declare->arg_types->size() == args.size())
        << adt::errors::TypeError{
               std::string() + func_declare->func_id + "() takes " +
               std::to_string(func_declare->arg_types->size()) +
               " arguments, but " + std::to_string(args.size()) +
               " were given"};
    for (int i = 0; i < args.size(); ++i) {
      const auto& arg_axpr_value = args.at(i);
      {
        // check arg type
        const auto& arg_type = func_declare->arg_types->at(i);
        ADT_LET_CONST_REF(arg_value,
                          CastToArgValue(arg_axpr_value));
        ADT_CHECK(arg_value.GetType() == arg_type) << adt::errors::TypeError{
            std::string() + "the argument " + std::to_string(i) + " of " +
            func_declare->func_id + "() should be a " + arg_type.Name() +
            "(not " + arg_value.GetType().Name() + ")"};
      }
      // The pointer aliases storage inside `args`, which outlives the call.
      ADT_LET_CONST_REF(ptr, GetAddrAsVoidPtr(arg_axpr_value));
      void_ptr_args.emplace_back(ptr);
    }
    ADT_RETURN_IF_ERR(function->dl_function.Apply(ret, void_ptr_args.data()));
    // ret_val was written in place through `ret` by the callee.
    return ret_val;
  }

  // Zero value of the declared type, used as the return slot.
  adt::Result GetDefaultVal(const ArgType& arg_type) {
    return arg_type.Match(
        [&](const axpr::DataType& data_type) -> adt::Result {
          ADT_LET_CONST_REF(data_value, GetDataTypeDefaultVal(data_type));
          return data_value;
        },
        [&](const axpr::PointerType& pointer_type) -> adt::Result {
          ADT_LET_CONST_REF(pointer_value,
                            GetPointerTypeDefaultVal(pointer_type));
          return pointer_value;
        });
  }

  // Value-initialized ("zero") DataValue for the given data type.
  adt::Result GetDataTypeDefaultVal(
      const axpr::DataType& data_type) {
    return data_type.Match(
        [&](const auto& impl) -> adt::Result {
          using T = typename std::decay_t::type;
          T val{};
          return axpr::DataValue{val};
        });
  }

  // Null pointer of the given pointer type.
  adt::Result GetPointerTypeDefaultVal(
      const axpr::PointerType& pointer_type) {
    return pointer_type.Match(
        [&](const auto& impl) -> adt::Result {
          using T = typename std::decay_t::type;
          T ptr = nullptr;
          return axpr::PointerValue{ptr};
        });
  }

  // Address of the payload held inside `arg_value`; const_cast is needed
  // because the C ABI takes a mutable void**. The pointer is only valid
  // while `arg_value` is alive.
  adt::Result GetAddrAsVoidPtr(const axpr::Value& arg_value) {
    return arg_value.Match(
        [&](const axpr::DataValue& data) -> adt::Result {
          return data.Match(
              [&](const adt::Undefined&) -> adt::Result {
                static_assert(
                    axpr::DataValue::IsMyAlternative(), "");
                // adt::Undefined represents cpp void type, because we cannot
                // define a void typed value.
                return nullptr;
              },
              [&](const auto& impl) -> adt::Result {
                using T = std::decay_t;
                return const_cast(&impl);
              });
        },
        [&](const axpr::PointerValue& ptr) -> adt::Result {
          return ptr.Match([&](const auto& impl) -> adt::Result {
            using T = std::decay_t;
            return const_cast(&impl);
          });
        },
        [&](const auto&) -> adt::Result {
          return adt::errors::TypeError{
              std::string() +
              "only DataValue or PointerValue are supported as so function "
              "arguments (not " +
              axpr::GetTypeName(arg_value) + ")."};
        });
  }
};

}  // namespace ap::rt_module
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/builtin_class_instance.h"
#include "paddle/ap/include/axpr/naive_class_ops.h"
#include "paddle/ap/include/axpr/type.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/ap/include/rt_module/function.h"
#include "paddle/ap/include/rt_module/function_helper.h"

namespace ap::rt_module {

// Axpr method table for rt_module::Function: exposes a loaded so function
// as a callable value inside the interpreter.
struct FunctionMethodClass {
  using Self = Function;

  // `self(args...)` -> FunctionHelper::Apply(self, args).
  static adt::Result Call(const axpr::Value& self_val,
                          const std::vector& args) {
    ADT_LET_CONST_REF(self, self_val.template CastTo());
    return FunctionHelper{}.Apply(self, args);
  }
};

// Builds (once per process, via the function-local static) the
// "so_function" builtin class whose __call__ dispatches to
// FunctionMethodClass::Call.
inline axpr::TypeImpl>
GetSoFunctionClass() {
  static auto cls(axpr::MakeBuiltinClass(
      "so_function", [&](const auto& DoEach) {
        DoEach("__call__", &FunctionMethodClass::Call);
      }));
  return axpr::MakeGlobalNaiveClassOps(cls);
}

}  // namespace ap::rt_module
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/ap/include/adt/adt.h"
+#include "paddle/ap/include/rt_module/function.h"
+
+namespace ap::rt_module {
+
+// Abstract interface over a loaded runtime module: a named collection of
+// callable functions (e.g. symbols exported from a compiled shared object).
+// Concrete implementation in this patch: NaiveModule (naive_module.h).
+class Module {
+ public:
+  virtual ~Module() = default;
+
+  // Looks up the function registered under `func_name`.
+  // NOTE(review): the adt::Result<> value type was lost in this patch text;
+  // judging by NaiveModule::Get it yields a Function on success and an
+  // error (e.g. KeyError) when the name is unknown — confirm against source.
+  virtual adt::Result Get(const std::string& func_name) const = 0;
+
+ protected:
+  // Only constructible through concrete subclasses.
+  Module() = default;
+};
+
+}  // namespace ap::rt_module
diff --git a/paddle/ap/include/rt_module/naive_dl_handler.h b/paddle/ap/include/rt_module/naive_dl_handler.h
new file mode 100644
index 00000000000000..af08d86ea0ab5a
--- /dev/null
+++ b/paddle/ap/include/rt_module/naive_dl_handler.h
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/rt_module/dl_function.h" +#include "paddle/ap/include/rt_module/dl_handle.h" + +namespace ap::rt_module { + +class NaiveDlHandle : public DlHandle { + public: + NaiveDlHandle(const NaiveDlHandle&) = default; + NaiveDlHandle(NaiveDlHandle&&) = default; + ~NaiveDlHandle() { + dlclose(main_handle_); + dlclose(api_wrappers_handle_); + } + + adt::Result DlSym(const std::string& name) const override { + void* function = dlsym(main_handle_, name.c_str()); + void* api_wrapper = dlsym(api_wrappers_handle_, name.c_str()); + ADT_CHECK(function != nullptr) + << adt::errors::ValueError{std::string() + "main so '" + name + + "' not found in '" + main_so_path_ + "'"}; + ADT_CHECK(api_wrapper != nullptr) << adt::errors::ValueError{ + std::string() + "api_wrapper so '" + name + "' not found in '" + + api_wrappers_so_path_ + "'"}; + std::shared_ptr self = shared_from_this(); + ADT_CHECK(self != nullptr); + using ApiWrapperT = void (*)(void* ret, void* func, void** args); + DlFunction ret{self, function, reinterpret_cast(api_wrapper)}; + return ret; + } + + static adt::Result> DlOpen( + const std::string& main_so_path, + const std::string& api_wrappers_so_path) { + void* main_handle = dlopen(main_so_path.c_str(), RTLD_LAZY); + if (!main_handle) { + return adt::errors::RuntimeError{ + std::string() + "dlopen failed. error message: " + dlerror() + + ". path: " + main_so_path}; + } + void* api_wrappers_handle = dlopen(api_wrappers_so_path.c_str(), RTLD_LAZY); + if (!api_wrappers_handle) { + dlclose(main_handle); + return adt::errors::RuntimeError{ + std::string() + "dlopen failed. error message: " + dlerror() + + ". 
path: " + api_wrappers_so_path}; + } + return std::shared_ptr(new NaiveDlHandle( + main_handle, api_wrappers_handle, main_so_path, api_wrappers_so_path)); + } + + private: + NaiveDlHandle(void* main_handle, + void* api_wrappers_handle, + const std::string& main_so_path, + const std::string& api_wrappers_so_path) + : main_handle_(main_handle), + api_wrappers_handle_(api_wrappers_handle), + main_so_path_(main_so_path), + api_wrappers_so_path_(api_wrappers_so_path) {} + + void* main_handle_; + void* api_wrappers_handle_; + std::string main_so_path_; + std::string api_wrappers_so_path_; +}; + +} // namespace ap::rt_module diff --git a/paddle/ap/include/rt_module/naive_module.h b/paddle/ap/include/rt_module/naive_module.h new file mode 100644 index 00000000000000..f2f0d3d0e0831f --- /dev/null +++ b/paddle/ap/include/rt_module/naive_module.h @@ -0,0 +1,63 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_module/func_declare.h" +#include "paddle/ap/include/rt_module/dl_handle.h" +#include "paddle/ap/include/rt_module/module.h" + +namespace ap::rt_module { + +class NaiveModule : public Module { + public: + NaiveModule(const NaiveModule&) = delete; + NaiveModule(NaiveModule&&) = delete; + + adt::Result Get(const std::string& func_name) const override { + const auto& iter = name2func_declare_.find(func_name); + ADT_CHECK(iter != name2func_declare_.end()) << adt::errors::KeyError{ + std::string() + "function " + func_name + " is not declared"}; + const auto& func_declare = iter->second; + ADT_LET_CONST_REF(dl_function, dl_handle_->DlSym(func_name)); + return Function{func_declare, dl_function}; + } + + static adt::Result> Make( + const std::vector& func_declares, + const std::shared_ptr& dl_handle) { + std::map name2func_declare{}; + for (const auto& func_declare : func_declares) { + ADT_CHECK( + name2func_declare.emplace(func_declare->func_id, func_declare).second) + << adt::errors::KeyError{ + std::string() + + "duplicated function name: " + func_declare->func_id}; + } + std::shared_ptr m( + new NaiveModule{name2func_declare, dl_handle}); + return m; + } + + private: + NaiveModule( + const std::map& name2func_declare, + const std::shared_ptr& dl_handle) + : name2func_declare_(name2func_declare), dl_handle_(dl_handle) {} + std::map name2func_declare_; + std::shared_ptr dl_handle_; +}; + +} // namespace ap::rt_module diff --git a/paddle/ap/include/rt_module/naive_module_maker.h b/paddle/ap/include/rt_module/naive_module_maker.h new file mode 100644 index 00000000000000..91231ae0f79002 --- /dev/null +++ b/paddle/ap/include/rt_module/naive_module_maker.h @@ -0,0 +1,151 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/code_module/api_wrapper_project_maker.h" +#include "paddle/ap/include/code_module/code_module.h" +#include "paddle/ap/include/code_module/project_compile_helper.h" +#include "paddle/ap/include/rt_module/naive_dl_handler.h" +#include "paddle/ap/include/rt_module/naive_module.h" + +namespace ap::rt_module { + +struct NaiveModuleMaker { + explicit NaiveModuleMaker(const std::string& workspace_dir_val) + : workspace_dir(workspace_dir_val) {} + + adt::Result> Make( + const code_module::CodeModule& code_module, + const std::function& + Serialize) const { + using RetT = adt::Result>; + return code_module->source_code.Match( + [&](const code_module::Project&) -> RetT { + return MakeByProject(code_module, Serialize); + }, + [&](const code_module::Package&) -> RetT { + return MakeByPackage(code_module, Serialize); + }); + } + + adt::Result> MakeByProject( + const code_module::CodeModule& code_module, + const std::function& + Serialize) const { + const auto& func_declares = code_module->func_declares.vector(); + ADT_LET_CONST_REF( + api_wrapper_project, + code_module::ApiWrapperProjectMaker{}.Make(func_declares)); + ADT_LET_CONST_REF(main_project, GetMainProject(code_module)); + code_module::ProjectCompileHelper api_wrapper_compile_helper( + GetApiWrapperProjectDir(), api_wrapper_project); + code_module::ProjectCompileHelper main_compile_helper(GetMainProjectDir(), + main_project); + const auto& serialized_project = Serialize(code_module); + if 
(FileExists(GetSerializedProjectFilePath())) { + ADT_LET_CONST_REF(dumped, ReadSerializedProject()); + ADT_CHECK(dumped == serialized_project); + } else { + ADT_RETURN_IF_ERR(api_wrapper_compile_helper.DumpNestedFilesToFs()); + ADT_RETURN_IF_ERR(main_compile_helper.DumpNestedFilesToFs()); + ADT_RETURN_IF_ERR(api_wrapper_compile_helper.Compile()); + ADT_RETURN_IF_ERR(main_compile_helper.Compile()); + ADT_RETURN_IF_ERR(WriteSerializedProject(serialized_project)); + } + std::string api_wrapper_so_path = api_wrapper_compile_helper.GetSoPath(); + std::string main_so_path = main_compile_helper.GetSoPath(); + ADT_LET_CONST_REF(dl_handler, + NaiveDlHandle::DlOpen(main_so_path, api_wrapper_so_path)); + return NaiveModule::Make(func_declares, dl_handler); + } + + adt::Result> MakeByPackage( + const code_module::CodeModule& code_module, + const std::function& + Serialize) const { + const auto& func_declares = code_module->func_declares.vector(); + ADT_LET_CONST_REF(package, GetPackage(code_module)); + std::string api_wrapper_so_path = + GetPackageDir() + "/" + package->api_wrapper_so_relative_path; + ADT_CHECK(FileExists(api_wrapper_so_path)) << adt::errors::TypeError{ + std::string() + + "FileExists(api_wrapper_so_path) failed. api_wrapper_so_path: " + + api_wrapper_so_path}; + std::string main_so_path = + GetPackageDir() + "/" + package->main_so_relative_path; + ADT_CHECK(FileExists(main_so_path)) << adt::errors::TypeError{ + std::string() + + "FileExists(main_so_path) failed. 
main_so_path: " + main_so_path}; + ADT_LET_CONST_REF(dl_handler, + NaiveDlHandle::DlOpen(main_so_path, api_wrapper_so_path)); + return NaiveModule::Make(func_declares, dl_handler); + } + + adt::Result ReadSerializedProject() const { + std::ifstream ifs(GetSerializedProjectFilePath()); + ADT_CHECK(ifs.is_open()); + std::stringstream buffer; + buffer << ifs.rdbuf(); + return buffer.str(); + } + + bool FileExists(const std::string& filepath) const { + std::fstream fp; + fp.open(filepath, std::fstream::in); + if (fp.is_open()) { + fp.close(); + return true; + } else { + return false; + } + } + + adt::Result WriteSerializedProject( + const std::string& serialized_project) const { + std::ofstream ofs(GetSerializedProjectFilePath()); + ADT_CHECK(ofs.is_open()); + ofs << serialized_project; + ofs.close(); + return adt::Ok{}; + } + + adt::Result GetMainProject( + const code_module::CodeModule& code_module) const { + return code_module->source_code.template TryGet(); + } + + adt::Result GetPackage( + const code_module::CodeModule& code_module) const { + return code_module->source_code.template TryGet(); + } + + std::string GetApiWrapperProjectDir() const { + return workspace_dir + "/api_wrapper/"; + } + + std::string GetPackageDir() const { return workspace_dir; } + + std::string GetMainProjectDir() const { return workspace_dir + "/main/"; } + + std::string GetSerializedProjectFilePath() const { + return workspace_dir + "/serialized_project.json"; + } + + private: + std::string workspace_dir; +}; + +} // namespace ap::rt_module diff --git a/paddle/ap/src/axpr/anf_expr.cc b/paddle/ap/src/axpr/anf_expr.cc new file mode 100644 index 00000000000000..a2129c58c3e41c --- /dev/null +++ b/paddle/ap/src/axpr/anf_expr.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/axpr/anf_expr.h" +#include "paddle/ap/include/axpr/anf_expr_builder.h" + +#include +#include +#include +#include "nlohmann/json.hpp" + +namespace ap::axpr { + +using adt::Result; + +using Json = nlohmann::json; + +Json ConvertAnfExprToJson(const AnfExpr& anf_expr); + +Json ConvertAtomicAnfExprToJson(const Atomic& atomic_expr) { + return atomic_expr.Match( + [&](const tVar& var) { + Json j = var.value(); + return j; + }, + [&](adt::Nothing) { + Json j; + return j; + }, + [&](bool c) { + Json j = c; + return j; + }, + [&](int64_t c) { + Json j = c; + return j; + }, + [&](double c) { + Json j = c; + return j; + }, + [&](const std::string& c) { + Json j; + j[AnfExpr::kString()] = c; + return j; + }, + [&](const Lambda& lambda) { + Json j = Json::array(); + j.push_back(AnfExpr::kLambda()); + j.push_back([&] { + Json j_args = Json::array(); + for (const auto& arg : lambda->args) { + j_args.push_back(arg.value()); + } + return j_args; + }()); + j.push_back(ConvertAnfExprToJson(lambda->body)); + return j; + }); +} + +Json ConvertCombinedAnfExprToJson(const Combined& combined_expr) { + return combined_expr.Match( + [&](const Call& call_expr) { + Json j; + j.push_back(ConvertAtomicAnfExprToJson(call_expr->func)); + for (const auto& arg : call_expr->args) { + j.push_back(ConvertAtomicAnfExprToJson(arg)); + } + return j; + }, + [&](const If& if_expr) { + Json j; + j.push_back(AnfExpr::kIf()); + j.push_back(ConvertAtomicAnfExprToJson(if_expr->cond)); + j.push_back(ConvertAnfExprToJson(if_expr->true_expr)); + 
j.push_back(ConvertAnfExprToJson(if_expr->false_expr)); + return j; + }); +} + +Json ConvertBindingAnfExprToJson(const Bind& binding_expr) { + Json j = Json::array(); + j.push_back(binding_expr.var.value()); + j.push_back(ConvertCombinedAnfExprToJson(binding_expr.val)); + return j; +} + +Json ConvertLetAnfExprToJson(const Let& let_expr) { + Json j; + j.push_back(AnfExpr::kLet()); + Json j_array = Json::array(); + for (const auto& binding : let_expr->bindings) { + j_array.push_back(ConvertBindingAnfExprToJson(binding)); + } + j.push_back(j_array); + j.push_back(ConvertAnfExprToJson(let_expr->body)); + return j; +} + +Json ConvertAnfExprToJson(const AnfExpr& anf_expr) { + return anf_expr.Match( + [&](const Atomic& atomic_expr) { + return ConvertAtomicAnfExprToJson(atomic_expr); + }, + [&](const Combined& combined_expr) { + return ConvertCombinedAnfExprToJson(combined_expr); + }, + [&](const Let& let_expr) { + return ConvertLetAnfExprToJson(let_expr); + }); +} + +std::string AnfExpr::DumpToJsonString() const { + return ConvertAnfExprToJson(*this).dump(); +} + +std::string AnfExpr::DumpToJsonString(int indent) const { + return ConvertAnfExprToJson(*this).dump(indent); +} + +} // namespace ap::axpr diff --git a/paddle/ap/src/axpr/builtin_functions.cc b/paddle/ap/src/axpr/builtin_functions.cc new file mode 100644 index 00000000000000..70e629d91ce6c8 --- /dev/null +++ b/paddle/ap/src/axpr/builtin_functions.cc @@ -0,0 +1,505 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/axpr/builtin_functions.h" +#include +#include +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/bool_helper.h" +#include "paddle/ap/include/axpr/bool_int_double_helper.h" +#include "paddle/ap/include/axpr/builtin_high_order_func_type.h" +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/axpr/data_value_util.h" +#include "paddle/ap/include/axpr/exception_method_class.h" +#include "paddle/ap/include/axpr/method_class.h" +#include "paddle/ap/include/axpr/string_util.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/axpr/value_method_class.h" + +namespace ap::axpr { + +Result BuiltinIdentity(const axpr::Value&, + const std::vector& args) { + if (args.size() != 1) { + return TypeError{std::string(kBuiltinIdentity()) + + "takes 1 argument, but " + std::to_string(args.size()) + + "were given."}; + } + return args.at(0); +} + +Result BuiltinNot(const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(bool_val, BoolHelper{}.ConvertToBool(args.at(0))); + return !bool_val; +} + +Result Raise(const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(exception, args.at(0).template CastTo()); + return exception.value(); +} + +Result BuiltinList(const axpr::Value&, + const std::vector& args) { + adt::List l; + for (const auto& arg : args) { + const auto& arg_ret = arg.Match( + [&](const Starred& starred) -> Result { + ADT_LET_CONST_REF( + sublist, starred->obj.template TryGet>()); + for (const auto& elt : *sublist) { + l->emplace_back(elt); + } + return adt::Ok{}; + }, + [&](const auto&) -> Result { + l->emplace_back(arg); + return adt::Ok{}; + }); + ADT_RETURN_IF_ERR(arg_ret); + } + return axpr::Value{l}; +} + +Result BuiltinHalt(const axpr::Value&, + const 
std::vector& args) { + return RuntimeError{"Dead code. Halt function should never be touched."}; +} + +adt::Result Print(InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + std::ostringstream ss; + int i = 0; + for (const auto& obj : args) { + if (i++ > 0) { + ss << " "; + } + const auto& func = MethodClass::ToString(obj); + using Ok = adt::Result; + ADT_RETURN_IF_ERR(func.Match( + [&](const adt::Nothing&) -> Ok { + return adt::errors::TypeError{std::string() + GetTypeName(obj) + + " class has no ToString method"}; + }, + [&](adt::Result (*unary_func)(const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(str_val, unary_func(obj)); + ADT_LET_CONST_REF(str, str_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(obj) + + ".__builtin_ToString__ should return a 'str' but '" + + axpr::GetTypeName(str_val) + "' were returned."}; + ss << str; + return adt::Ok{}; + }, + [&](adt::Result (*unary_func)( + InterpreterBase*, const axpr::Value&)) -> Ok { + ADT_LET_CONST_REF(str_val, unary_func(interpreter, obj)); + ADT_LET_CONST_REF(str, str_val.template TryGet()) + << adt::errors::TypeError{ + std::string() + "'" + axpr::GetTypeName(obj) + + ".__builtin_ToString__ should return a 'str' but '" + + axpr::GetTypeName(str_val) + "' were returned."}; + ss << str; + return adt::Ok{}; + })); + } + LOG(ERROR) << "Print\n" << ss.str(); + return adt::Nothing{}; +} + +adt::Result ReplaceOrTrimLeftComma( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string() + "'replace_or_trim_left_comma' takes 3 arguments but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(self, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of 'replace_or_trim_left_comma' should be a str " + "(not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_LET_CONST_REF(pattern, args.at(1).template TryGet()) + << 
adt::errors::TypeError{ + std::string() + + "the argument 2 of 'replace_or_trim_left_comma' should be a str " + "(not '" + + axpr::GetTypeName(args.at(1)) + "')."}; + ADT_LET_CONST_REF(replacement, args.at(2).template TryGet()) + << adt::errors::TypeError{ + std::string() + + "the argument 3 of 'replace_or_trim_left_comma' should be a str " + "(not '" + + axpr::GetTypeName(args.at(2)) + "')."}; + std::size_t pattern_pos = self.find(pattern); + if (pattern_pos == std::string::npos) { + return self; + } + auto EquivalentComma = + [](const std::string& self, std::size_t start, std::size_t end) { + if (start == std::string::npos) { + return false; + } + if (start < 0) { + return false; + } + if (start >= self.size()) { + return false; + } + if (end == std::string::npos) { + return false; + } + if (end < 0) { + return false; + } + if (end >= self.size()) { + return false; + } + if (start >= end) { + return false; + } + if (self[start] != ',') { + return false; + } + for (int i = start + 1; i < end; ++i) { + char ch = self[i]; + if (ch == ' ') { + continue; + } + if (ch == '\r') { + continue; + } + if (ch == '\n') { + continue; + } + if (ch == '\t') { + continue; + } + return false; + } + return true; + }; + if (replacement.empty()) { + std::string str = self; + while (true) { + std::size_t pattern_pos = self.find(pattern); + if (pattern_pos == std::string::npos) { + break; + } + std::size_t comma_pos = str.rfind(',', pattern_pos); + if (!EquivalentComma(str, comma_pos, pattern_pos)) { + break; + } + str = str.replace(comma_pos, pattern_pos + pattern.size(), ""); + } + return str; + } else { + std::string str = self; + while (true) { + std::size_t pos = str.find(pattern); + if (pos == std::string::npos) { + break; + } + str = str.replace(pos, pattern.size(), replacement); + } + return str; + } +} + +adt::Result MakeRange(const axpr::Value&, + const std::vector& args) { + std::optional start; + std::optional end; + if (args.size() == 1) { + start = 0; + 
ADT_LET_CONST_REF(arg0, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + "'range' takes int argument but " + + axpr::GetTypeName(args.at(0)) + " were given."}; + end = arg0; + } else if (args.size() == 2) { + ADT_LET_CONST_REF(arg0, args.at(0).template TryGet()) + << adt::errors::TypeError{ + std::string() + "'range' takes int argument but " + + axpr::GetTypeName(args.at(0)) + " were given."}; + ADT_LET_CONST_REF(arg1, args.at(1).template TryGet()) + << adt::errors::TypeError{ + std::string() + "'range' takes int argument but " + + axpr::GetTypeName(args.at(1)) + " were given."}; + start = arg0; + end = arg1; + } else { + ADT_CHECK(false) << adt::errors::TypeError{ + std::string() + "'range' takes 1 or 2 arguments but " + + std::to_string(args.size()) + " were given."}; + } + ADT_CHECK(start.has_value()); + ADT_CHECK(end.has_value()); + adt::List ret; + ret->reserve((start.value() > end.value() ? 0 : end.value() - start.value())); + for (int64_t i = start.value(); i < end.value(); ++i) { + ret->emplace_back(i); + } + return ret; +} + +Result Map(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) + << adt::errors::TypeError{std::string() + "map() takes 2 arguments but " + + std::to_string(args.size()) + " were given."}; + + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(1))); + ADT_LET_CONST_REF(lst_size, lst.size()); + adt::List ret; + ret->reserve(lst_size); + const auto& f = args.at(0); + ADT_RETURN_IF_ERR( + lst.Visit([&](const auto& elt) -> adt::Result { + ADT_LET_CONST_REF( + converted_elt, + interpreter->InterpretCall(f, std::vector{elt})); + ret->emplace_back(converted_elt); + return adt::Continue{}; + })); + return ret; +} + +Result Apply(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "apply() takes 2 arguments but " + + 
std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(1))); + ADT_LET_CONST_REF(lst_size, lst.size()); + std::vector func_args; + func_args.reserve(lst_size); + ADT_RETURN_IF_ERR( + lst.Visit([&](const auto& elt) -> adt::Result { + func_args.push_back(elt); + return adt::Continue{}; + })); + return interpreter->InterpretCall(args.at(0), func_args); +} + +Result Length(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) + << adt::errors::TypeError{std::string() + "len() takes 1 arguments but " + + std::to_string(args.size()) + " were given."}; + axpr::Value len_symbol{builtin_symbol::Symbol{builtin_symbol::Length{}}}; + return interpreter->InterpretCall(len_symbol, args); +} + +Result FlatMap(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "flat_map() takes 2 arguments but " + + std::to_string(args.size()) + " were given."}; + + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(1))); + ADT_LET_CONST_REF(lst_size, lst.size()); + adt::List ret; + ret->reserve(lst_size); + auto Collect = [&](const auto& sub_elt) -> adt::Result { + ret->emplace_back(sub_elt); + return adt::Continue{}; + }; + const auto& f = args.at(0); + ADT_RETURN_IF_ERR( + lst.Visit([&](const auto& elt) -> adt::Result { + ADT_LET_CONST_REF( + converted_elt, + interpreter->InterpretCall(f, std::vector{elt})); + ADT_LET_CONST_REF(a_list, + AbstractList::CastFrom(converted_elt)) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of flat_map() should be a function " + "returning a list/SerializableList/MutableList (not a " + + axpr::GetTypeName(converted_elt) + ")"}; + ADT_RETURN_IF_ERR(a_list.Visit(Collect)); + return adt::Continue{}; + })); + return ret; +} + +Result Filter(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const 
std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "filter() takes 2 arguments but " + + std::to_string(args.size()) + " were given."}; + + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(1))); + ADT_LET_CONST_REF(lst_size, lst.size()); + adt::List ret; + ret->reserve(lst_size); + const auto& f = args.at(0); + ADT_RETURN_IF_ERR( + lst.Visit([&](const auto& elt) -> adt::Result { + ADT_LET_CONST_REF( + filter_result, + interpreter->InterpretCall(f, std::vector{elt})); + ADT_LET_CONST_REF(is_true, BoolHelper{}.ConvertToBool(filter_result)); + if (is_true) { + ret->emplace_back(elt); + } + return adt::Continue{}; + })); + return ret; +} + +Result Zip(const axpr::Value&, + const std::vector& args) { + std::optional size; + for (const auto& arg : args) { + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(arg)) + << adt::errors::TypeError{std::string() + + "the argument of 'zip' should be list."}; + ADT_LET_CONST_REF(lst_size, lst.size()); + if (size.has_value()) { + ADT_CHECK(size.value() == lst_size) << adt::errors::TypeError{ + std::string() + "the arguments of 'zip' should be the same size."}; + } else { + size = lst_size; + } + } + adt::List ret; + ret->reserve(size.value()); + for (int i = 0; i < size.value(); ++i) { + adt::List tuple; + tuple->reserve(args.size()); + for (const auto& arg : args) { + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(arg)); + ADT_LET_CONST_REF(elt, lst.at(i)); + tuple->emplace_back(elt); + } + ret->emplace_back(tuple); + } + return ret; +} + +Result Reduce(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2 || args.size() == 3) << adt::errors::TypeError{ + std::string() + "'reduce' takes 2 or 3 arguments but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(1))); + std::optional init; + std::optional start; + ADT_LET_CONST_REF(lst_size, 
lst.size()); + if (lst_size > 0) { + ADT_LET_CONST_REF(init_val, lst.at(0)); + init = init_val; + start = 1; + } else { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string() + "reduce() of empty sequence with no initial value"}; + init = args.at(2); + start = 0; + } + ADT_CHECK(init.has_value()); + ADT_CHECK(start.has_value()); + axpr::Value ret{init.value()}; + const auto& f = args.at(0); + for (int i = start.value(); i < lst_size; ++i) { + ADT_LET_CONST_REF(elt, lst.at(i)); + ADT_LET_CONST_REF( + cur_reduced, + interpreter->InterpretCall(f, std::vector{elt, ret})); + ret = cur_reduced; + } + return ret; +} + +Result Max(const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) + << adt::errors::TypeError{std::string() + "max() takes 2 arguments but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lhs, BoolIntDouble::CastFrom(args.at(0))) + << adt::errors::TypeError{std::string() + + "the argument 1 of max() should be 'bool', " + "'int' or 'float' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_LET_CONST_REF(rhs, BoolIntDouble::CastFrom(args.at(1))) + << adt::errors::TypeError{std::string() + + "the argument 1 of max() should be 'bool', " + "'int' or 'float' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + BoolIntDoubleHelper helper{}; + ADT_LET_CONST_REF(cmp_ret, + helper.template BinaryFunc(lhs, rhs)); + ADT_LET_CONST_REF(cmp, cmp_ret.template TryGet()); + return cmp ? 
args.at(0) : args.at(1); +} + +Result Min(const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) + << adt::errors::TypeError{std::string() + "min() takes 2 arguments but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(lhs, BoolIntDouble::CastFrom(args.at(0))) + << adt::errors::TypeError{std::string() + + "the argument 1 of min() should be 'bool', " + "'int' or 'float' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + ADT_LET_CONST_REF(rhs, BoolIntDouble::CastFrom(args.at(1))) + << adt::errors::TypeError{std::string() + + "the argument 1 of min() should be 'bool', " + "'int' or 'float' (not '" + + axpr::GetTypeName(args.at(0)) + "')."}; + BoolIntDoubleHelper helper{}; + ADT_LET_CONST_REF(cmp_ret, + helper.template BinaryFunc(lhs, rhs)); + ADT_LET_CONST_REF(cmp, cmp_ret.template TryGet()); + return cmp ? args.at(0) : args.at(1); +} + +Result GetAttr(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "getattr() takes 2 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF( + ret, interpreter->InterpretCall(builtin_symbol::GetAttr{}, args)); + return ret; +} + +Result SetAttr(axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string() + "setattr() takes 3 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(func, + interpreter->InterpretCall(builtin_symbol::SetAttr{}, + {args.at(0), args.at(1)})); + ADT_LET_CONST_REF(ret, + interpreter->InterpretCall(func, {args.at(1), args.at(2)})); + return ret; +} + +} // namespace ap::axpr diff --git a/paddle/ap/src/axpr/core_expr.cc b/paddle/ap/src/axpr/core_expr.cc new file mode 100644 index 00000000000000..91891fb3c5ad1d --- /dev/null +++ b/paddle/ap/src/axpr/core_expr.cc @@ -0,0 +1,81 @@ +// Copyright 
(c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/axpr/core_expr.h" +#include +#include +#include +#include "paddle/ap/include/axpr/core_expr_builder.h" + +namespace ap::axpr { + +namespace { + +std::string AtomicExprToSExpression(const Atomic& core_expr) { + return core_expr.Match( + [](const Symbol& symbol) { return symbol.Name(); }, + [](const adt::Nothing) { return std::string("()"); }, + [](const bool c) { return c ? 
std::string("#t") : std::string("#f"); }, + [](const int64_t c) { return std::to_string(c); }, + [](const double c) { return std::to_string(c); }, + [](const std::string& str) { + std::ostringstream ss; + ss << std::quoted(str); + return ss.str(); + }, + [](const Lambda& lambda) { + std::ostringstream ss; + ss << "(lambda ["; + int i = 0; + for (const auto& arg : lambda->args) { + if (i++ > 0) { + ss << " "; + } + ss << arg.value(); + } + ss << "] "; + ss << lambda->body.ToSExpression(); + ss << ")"; + return ss.str(); + }); +} + +std::string ComposedCallExprToSExpression( + const ComposedCallAtomic& core_expr) { + std::ostringstream ss; + ss << "("; + ss << AtomicExprToSExpression(core_expr->outer_func); + ss << " "; + ss << AtomicExprToSExpression(core_expr->inner_func); + for (const auto& arg : core_expr->args) { + ss << " "; + ss << AtomicExprToSExpression(arg); + } + ss << ")"; + return ss.str(); +} + +} // namespace + +std::string CoreExpr::ToSExpression() const { + return Match( + [&](const Atomic& core_expr) { + return AtomicExprToSExpression(core_expr); + }, + [&](const ComposedCallAtomic& core_expr) { + return ComposedCallExprToSExpression(core_expr); + }); +} + +} // namespace ap::axpr diff --git a/paddle/ap/src/axpr/exception_method_class.cc b/paddle/ap/src/axpr/exception_method_class.cc new file mode 100644 index 00000000000000..e585abad4d224e --- /dev/null +++ b/paddle/ap/src/axpr/exception_method_class.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/axpr/exception_method_class.h" + +namespace ap::axpr { + +struct ExceptionMethodClass { + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + ss << self.value().class_name() << ": " << self.value().msg(); + return ss.str(); + } +}; + +axpr::TypeImpl> GetExceptionClass() { + static auto cls( + axpr::MakeBuiltinClass("Exception", [&](const auto& Yield) { + Yield("__str__", &ExceptionMethodClass::ToString); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::axpr diff --git a/paddle/ap/src/axpr/interpreter.cc b/paddle/ap/src/axpr/interpreter.cc new file mode 100644 index 00000000000000..7c17004b0cb8ce --- /dev/null +++ b/paddle/ap/src/axpr/interpreter.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/cps_interpreter.h" + +namespace ap::axpr { + +adt::Result Interpreter::Interpret( + const Lambda& lambda, const std::vector& args) { + CpsInterpreter cps_interpreter{builtin_frame_attr_map_, circlable_ref_list_}; + return cps_interpreter.Interpret(lambda, args); +} + +adt::Result Interpreter::Interpret( + const axpr::Value& function, const std::vector& args) { + CpsInterpreter cps_interpreter{builtin_frame_attr_map_, circlable_ref_list_}; + return cps_interpreter.Interpret(function, args); +} + +adt::Result Interpreter::InterpretModule( + const Frame& const_global_frame, + const Lambda& lambda) { + CpsInterpreter cps_interpreter{builtin_frame_attr_map_, circlable_ref_list_}; + return cps_interpreter.InterpretModule(const_global_frame, lambda); +} + +} // namespace ap::axpr diff --git a/paddle/ap/src/axpr/s_expr.cc b/paddle/ap/src/axpr/s_expr.cc new file mode 100644 index 00000000000000..b27399056a4ac5 --- /dev/null +++ b/paddle/ap/src/axpr/s_expr.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/axpr/s_expr.h" +#include +#include +#include + +namespace ap::axpr { + +namespace { + +std::string AtomicExprToSExpression(const Atomic& s_expr) { + return s_expr.Match( + [](const tVar& var) { return var.value(); }, + [](const adt::Nothing&) { return std::string("()"); }, + [](const bool c) { return c ? std::string("#t") : std::string("#f"); }, + [](const int64_t c) { return std::to_string(c); }, + [](const double c) { return std::to_string(c); }, + [](const std::string& str) { + std::ostringstream ss; + ss << std::quoted(str); + return ss.str(); + }, + [](const Lambda& lambda) { + std::ostringstream ss; + ss << "(lambda ["; + int i = 0; + for (const auto& arg : lambda->args) { + if (i++ > 0) { + ss << " "; + } + ss << arg.value(); + } + ss << "] "; + ss << lambda->body.ToSExpression(); + ss << ")"; + return ss.str(); + }); +} + +std::string SListExprToSExpression(const SList& s_expr) { + std::ostringstream ss; + int i = 0; + ss << "("; + for (const auto& child : s_expr->children) { + if (i++ > 0) { + ss << " "; + } + ss << child.ToSExpression(); + } + ss << ")"; + return ss.str(); +} + +} // namespace + +std::string SExpr::ToSExpression() const { + return Match( + [&](const Atomic& s_expr) { + return AtomicExprToSExpression(s_expr); + }, + [&](const SList& s_expr) { + return SListExprToSExpression(s_expr); + }); +} + +} // namespace ap::axpr diff --git a/paddle/ap/src/code_gen/code_gen_result_method_class.cc b/paddle/ap/src/code_gen/code_gen_result_method_class.cc new file mode 100644 index 00000000000000..3f5a2f67512053 --- /dev/null +++ b/paddle/ap/src/code_gen/code_gen_result_method_class.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/code_gen/code_gen_result_method_class.h" + +namespace ap::code_gen { + +template +struct TypeImplCodeGenResultMethodClass { + using This = TypeImplCodeGenResultMethodClass; + using Self = axpr::TypeImpl>; + + static adt::Result Construct(const ValueT& self_val, + const std::vector& args) { + return This{}.Make(self_val, args); + } + + adt::Result Make(const ValueT& self_val, + const std::vector& packed_args_val) { + ADT_LET_CONST_REF( + empty_self, + self_val.template TryGet>()); + const auto& packed_args = axpr::CastToPackedArgs(packed_args_val); + const auto& [args, kwargs] = *packed_args; + ADT_LET_CONST_REF(module_val, kwargs->Get("module")) + << adt::errors::TypeError{ + std::string() + + "the constructor of 'CodeGenResult' missing keyword argument " + "'module' of type 'CodeModule'."}; + ADT_LET_CONST_REF(m, axpr::Get(module_val)) + << adt::errors::TypeError{ + std::string() + + "the constructor of 'CodeGenResult' missing keyword argument " + "'module' of type 'CodeModule'."}; + ADT_LET_CONST_REF( + kernel_dispatch_func, + kwargs->template TryGet>( + "kernel_dispatch_func")) + << adt::errors::TypeError{ + std::string() + + "the constructor of 'CodeGenResult' missing keyword argument " + "'kernel_dispatch_func' of type 'Function'."}; + std::optional> + kernel_dispatch_const_data; + if (kwargs->Has("kernel_dispatch_const_data")) { + ADT_LET_CONST_REF( + data, + kwargs->template TryGet>( + "kernel_dispatch_const_data")) + << adt::errors::TypeError{ + std::string() + + "the constructor of 
'CodeGenResult' needs keyword argument " + "'kernel_dispatch_const_data' of type " + "'BuiltinSerializableAttrMap'."}; + kernel_dispatch_const_data = data; + } else { + kernel_dispatch_const_data = axpr::AttrMap{}; + } + ADT_CHECK(kernel_dispatch_const_data.has_value()); + return empty_self.type.New(CodeGenResult{ + m, kernel_dispatch_func, kernel_dispatch_const_data.value()}); + } +}; + +axpr::TypeImpl> +GetCodeGenResultClass() { + using TypeImplMethods = TypeImplCodeGenResultMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "CodeGenResult", [&](const auto& Define) { + Define("__init__", &TypeImplMethods::Construct); + })); + using Self = CodeGenResult; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_gen diff --git a/paddle/ap/src/code_module/code_module_method_class.cc b/paddle/ap/src/code_module/code_module_method_class.cc new file mode 100644 index 00000000000000..dad863ca372389 --- /dev/null +++ b/paddle/ap/src/code_module/code_module_method_class.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+
+
+#include "paddle/ap/include/code_module/code_module_method_class.h"
+
+namespace ap::code_module {
+
+template
+struct TypeImplCodeModuleMethodClass {
+  using This = TypeImplCodeModuleMethodClass;
+  using Self = axpr::TypeImpl;
+
+  // Builds a CodeModule from (func_declares, source_code).
+  // NOTE(review): "arguments. but ...were given." fixed — missing space and
+  // wrong punctuation made the message render as e.g. "arguments. but 3were".
+  static adt::Result Make(const std::vector& args) {
+    ADT_CHECK(args.size() == 2) << adt::errors::TypeError{
+        std::string("the constructor of 'CodeModule' takes 2 arguments, but ") +
+        std::to_string(args.size()) + " were given."};
+    // Argument 1 may be a single FuncDeclare or a list of them; normalize.
+    const auto& list = args.at(0).Match(
+        [&](const adt::List& l) -> adt::List { return l; },
+        [&](const auto& impl) -> adt::List {
+          return adt::List{ValueT{impl}};
+        });
+    adt::List func_declares;
+    func_declares->reserve(list->size());
+    for (const auto& elt : *list) {
+      ADT_LET_CONST_REF(func_declare, axpr::Get(elt))
+          << adt::errors::TypeError{
+                 std::string() +
+                 "the argument 1 of constructor of 'CodeModule' should be a "
+                 "'FuncDeclare' object or a list of 'FuncDeclare' object."};
+      func_declares->emplace_back(func_declare);
+    }
+    ADT_LET_CONST_REF(source_code, SourceCode::CastFromAxprValue(args.at(1)))
+        << adt::errors::TypeError{std::string() +
+                                  "the argument 2 of CodeModule() should be a "
+                                  "'Project' (not " +
+                                  axpr::GetTypeName(args.at(1)) + ") object"};
+    return CodeModule{func_declares, source_code};
+  }
+};
+
+// __init__ entry: fills the empty instance created by the class machinery.
+template
+adt::Result InitCodeModule(const ValueT& self_val,
+                           const std::vector& args) {
+  ADT_LET_CONST_REF(
+      empty_self,
+      self_val.template TryGet>());
+  ADT_LET_CONST_REF(m, TypeImplCodeModuleMethodClass::Make(args));
+  return empty_self.type.New(m);
+}
+
+axpr::TypeImpl> GetCodeModuleClass() {
+  static auto cls(axpr::MakeBuiltinClass(
+      "CodeModule", [&](const auto& DoEach) {
+        DoEach("__init__", &InitCodeModule);
+      }));
+  using Self = CodeModule;
+  return axpr::MakeGlobalNaiveClassOps(cls);
+}
+
+} // namespace ap::code_module
diff --git a/paddle/ap/src/code_module/directory_method_class.cc b/paddle/ap/src/code_module/directory_method_class.cc
new file mode
100644
index 00000000000000..9e87b343fac6c8
--- /dev/null
+++ b/paddle/ap/src/code_module/directory_method_class.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+
+#include "paddle/ap/include/code_module/directory_method_class.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl> GetDirectoryClass();
+
+struct TypeDirectoryClassMethodClass {
+  // Directory(*pairs): each argument is a [dentry-name, file-like] pair.
+  // NOTE(review): messages fixed — "acepts" typo and the garbled
+  // "the argument of N Directory()" word order.
+  static adt::Result New(const axpr::Value&,
+                         const std::vector& args) {
+    axpr::AttrMap dentry2file{};
+    int i = 0;
+    for (const auto& arg : args) {
+      ++i;
+      ADT_LET_CONST_REF(pair, arg.template CastTo>())
+          << adt::errors::TypeError{
+                 std::string() + "the argument " + std::to_string(i) +
+                 " of Directory() should be a [str, Project.Directory | "
+                 "Project.FileContent | Project.SoftLink]"
+                 ", but " +
+                 axpr::GetTypeName(arg) + " were given"};
+      ADT_CHECK(pair->size() == 2) << adt::errors::TypeError{
+          std::string() + "the argument " + std::to_string(i) +
+          " of Directory() should be a [str, Project.Directory | "
+          "Project.FileContent | Project.SoftLink]"
+          ", but its length is " +
+          std::to_string(pair->size())};
+      ADT_LET_CONST_REF(dentry, pair->at(0).template CastTo())
+          << adt::errors::TypeError{
+                 std::string() + "the argument " + std::to_string(i) +
+                 " of Directory() only accepts list of [str, Project.Directory | "
+                 "Project.FileContent | Project.SoftLink]"
+                 ". but the first of pair is a " +
+                 axpr::GetTypeName(pair->at(0))};
+      ADT_LET_CONST_REF(file, File::CastFromAxprValue(pair->at(1)))
+          << adt::errors::TypeError{
+                 std::string() +
+                 "Directory() only accepts list of [str, Project.Directory | "
+                 "Project.FileContent | Project.SoftLink]"
+                 ", but the second of pair is a " +
+                 axpr::GetTypeName(pair->at(1))};
+      dentry2file->Set(dentry, file);
+    }
+    return GetDirectoryClass().New(Directory{dentry2file});
+  }
+};
+
+axpr::TypeImpl> GetDirectoryClass() {
+  static auto cls(
+      axpr::MakeBuiltinClass("Directory", [&](const auto& DoEach) {
+        DoEach("__init__", &TypeDirectoryClassMethodClass::New);
+      }));
+  return axpr::MakeGlobalNaiveClassOps>(cls);
+}
+
+} // namespace ap::code_module
diff --git a/paddle/ap/src/code_module/file_content_method_class.cc b/paddle/ap/src/code_module/file_content_method_class.cc
new file mode 100644
index 00000000000000..9639650f74b743
--- /dev/null
+++ b/paddle/ap/src/code_module/file_content_method_class.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+
+#include "paddle/ap/include/code_module/file_content_method_class.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl> GetFileContentClass();
+
+struct TypeFileContentMethodClass {
+  // FileContent(str): wraps a string as an in-memory file body.
+  // NOTE(review): "should a a str" typo fixed in the error message.
+  static adt::Result New(const axpr::Value&,
+                         const std::vector& args) {
+    ADT_CHECK(args.size() == 1) << adt::errors::TypeError{
+        std::string() + "FileContent() takes 1 argument, but " +
+        std::to_string(args.size()) + " were given."};
+    ADT_LET_CONST_REF(file_content, args.at(0).template CastTo())
+        << adt::errors::TypeError{
+               std::string() +
+               "the argument 1 of FileContent() should be a str, but " +
+               axpr::GetTypeName(args.at(0)) + " were given"};
+    return GetFileContentClass().New(FileContent{file_content});
+  }
+};
+
+axpr::TypeImpl> GetFileContentClass() {
+  static auto cls(axpr::MakeBuiltinClass(
+      "FileContent", [&](const auto& DoEach) {
+        DoEach("__init__", &TypeFileContentMethodClass::New);
+      }));
+  return axpr::MakeGlobalNaiveClassOps(cls);
+}
+
+} // namespace ap::code_module
diff --git a/paddle/ap/src/code_module/func_declare_method_class.cc b/paddle/ap/src/code_module/func_declare_method_class.cc
new file mode 100644
index 00000000000000..c8d0fb82dfca14
--- /dev/null
+++ b/paddle/ap/src/code_module/func_declare_method_class.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include "paddle/ap/include/code_module/func_declare_method_class.h" + +namespace ap::code_module { + +template +struct TypeImplFuncDeclareMethodClass { + using This = TypeImplFuncDeclareMethodClass; + using Self = axpr::TypeImpl; + + static adt::Result Make(const std::vector& args) { + ADT_CHECK(args.size() == 3) << adt::errors::TypeError{ + std::string("the constructor of FuncDeclare takes 3 arguments but ") + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(ret_type, CastToArgType(args.at(0))) + << adt::errors::TypeError{std::string() + + "the argument 1 of FuncDeclare() should be a " + "'DataType or PointerType'"}; + ADT_LET_CONST_REF(func_id, axpr::TryGetImpl(args.at(1))) + << adt::errors::TypeError{std::string() + + "the argument 2 of " + "FuncDeclare() should be a 'str'"}; + ADT_LET_CONST_REF(arg_types, GetArgTypes(args.at(2))); + return FuncDeclare{ret_type, func_id, arg_types}; + } + + static Result> GetArgTypes(const ValueT& val) { + ADT_LET_CONST_REF(list, axpr::TryGetImpl>(val)) + << adt::errors::TypeError{std::string() + + "the argument 2 of construct of FuncDeclare " + "should be a list of DataType " + "or PointerType."}; + adt::List ret; + ret->reserve(list->size()); + for (const auto& elt : *list) { + ADT_LET_CONST_REF(arg_type, CastToArgType(elt)) + << adt::errors::TypeError{std::string() + + "the argument 2 of construct of " + "FuncDeclare should be a list of DataType " + "or PointerType."}; + ret->emplace_back(arg_type); + } + return ret; + } +}; + +template +adt::Result InitFuncDeclare(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + empty_self, + self_val.template TryGet>()); + ADT_LET_CONST_REF(func_declare, + TypeImplFuncDeclareMethodClass::Make(args)); + return empty_self.type.New(func_declare); +} + +axpr::TypeImpl> GetFuncDeclareClass() { + static auto cls(axpr::MakeBuiltinClass( + "FuncDeclare", [&](const auto& DoEach) { + DoEach("__init__", &InitFuncDeclare); + })); + 
using Self = FuncDeclare; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_module diff --git a/paddle/ap/src/code_module/package_method_class.cc b/paddle/ap/src/code_module/package_method_class.cc new file mode 100644 index 00000000000000..9ffc7d0d41bed7 --- /dev/null +++ b/paddle/ap/src/code_module/package_method_class.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/code_module/package_method_class.h" + +namespace ap::code_module { + +axpr::TypeImpl> GetPackageClass(); + +struct TypePackageClassMethodClass { + static adt::Result New( + const axpr::Value&, const std::vector& args_vec) { + const auto& packed_args = axpr::CastToPackedArgs(args_vec); + const auto& [args, kwargs] = *packed_args; + ADT_CHECK(args->empty()) << adt::errors::TypeError{ + std::string() + "Package() takes no positional argument, bug " + + std::to_string(args->size()) + " were given"}; + axpr::AttrMap dentry2file{}; + ADT_LET_CONST_REF(directory_val, kwargs->Get("nested_files")) + << adt::errors::TypeError{ + std::string() + + "Package() need the keyword argument 'nested_files'"}; + ADT_LET_CONST_REF(directory, + directory_val.template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the keyword argument 'nested_files' of Package() should be a " + "Directory, but " + + axpr::GetTypeName(directory_val) + " were 
given"}; + ADT_LET_CONST_REF(api_wrapper_so_relative_path_val, + kwargs->Get("api_wrapper_so_relative_path")) + << adt::errors::TypeError{std::string() + + "Package() need the keyword argument " + "'api_wrapper_so_relative_path'"}; + ADT_LET_CONST_REF( + api_wrapper_so_relative_path, + api_wrapper_so_relative_path_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "the keyword argument 'api_wrapper_so_relative_path' of " + "Package() should be a str, but " + + axpr::GetTypeName(api_wrapper_so_relative_path_val) + + " were given"}; + ADT_LET_CONST_REF(main_so_relative_path_val, + kwargs->Get("main_so_relative_path")) + << adt::errors::TypeError{ + std::string() + + "Package() need the keyword argument 'main_so_relative_path'"}; + ADT_LET_CONST_REF(main_so_relative_path, + main_so_relative_path_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "the keyword argument 'main_so_relative_path' of " + "Package() should be a str, but " + + axpr::GetTypeName(main_so_relative_path_val) + " were given"}; + axpr::AttrMap others; + if (kwargs->Has("others")) { + ADT_LET_CONST_REF(others_val, kwargs->Get("others")) + << adt::errors::TypeError{ + std::string() + + "Package() need the keyword argument 'others'"}; + ADT_LET_CONST_REF( + others_attrs, + others_val.template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the keyword argument 'others' of Package() should be a " + "BuiltinSerializableAttrMap, but " + + axpr::GetTypeName(others_val) + " were given"}; + others = others_attrs; + } + return GetPackageClass().New(Package{directory, + api_wrapper_so_relative_path, + main_so_relative_path, + others}); + } +}; + +axpr::TypeImpl> GetPackageClass() { + static auto cls( + axpr::MakeBuiltinClass("Package", [&](const auto& DoEach) { + DoEach("__init__", &TypePackageClassMethodClass::New); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_module diff --git 
a/paddle/ap/src/code_module/project_method_class.cc b/paddle/ap/src/code_module/project_method_class.cc new file mode 100644 index 00000000000000..20e0775b7fe935 --- /dev/null +++ b/paddle/ap/src/code_module/project_method_class.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/code_module/project_method_class.h" + +namespace ap::code_module { + +axpr::TypeImpl> GetProjectClass(); + +struct TypeProjectClassMethodClass { + static adt::Result New( + const axpr::Value&, const std::vector& args_vec) { + const auto& packed_args = axpr::CastToPackedArgs(args_vec); + const auto& [args, kwargs] = *packed_args; + ADT_CHECK(args->empty()) << adt::errors::TypeError{ + std::string() + "Project() takes no positional argument, bug " + + std::to_string(args->size()) + " were given"}; + axpr::AttrMap dentry2file{}; + ADT_LET_CONST_REF(directory_val, kwargs->Get("nested_files")) + << adt::errors::TypeError{ + std::string() + + "Project() need the keyword argument 'nested_files'"}; + ADT_LET_CONST_REF(directory, + directory_val.template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the keyword argument 'nested_files' of Project() should be a " + "Project.Directory, but " + + axpr::GetTypeName(directory_val) + " were given"}; + ADT_LET_CONST_REF(compile_cmd_val, kwargs->Get("compile_cmd")) + << adt::errors::TypeError{ + 
std::string() + + "Project() need the keyword argument 'compile_cmd'"}; + ADT_LET_CONST_REF(compile_cmd, + compile_cmd_val.template CastTo()) + << adt::errors::TypeError{std::string() + + "the keyword argument 'compile_cmd' of " + "Project() should be a str, but " + + axpr::GetTypeName(compile_cmd_val) + + " were given"}; + ADT_LET_CONST_REF(so_relative_path_val, kwargs->Get("so_relative_path")) + << adt::errors::TypeError{ + std::string() + + "Project() need the keyword argument 'so_relative_path'"}; + ADT_LET_CONST_REF(so_relative_path, + so_relative_path_val.template CastTo()) + << adt::errors::TypeError{std::string() + + "the keyword argument 'so_relative_path' of " + "Project() should be a str, but " + + axpr::GetTypeName(so_relative_path_val) + + " were given"}; + axpr::AttrMap others; + if (kwargs->Has("others")) { + ADT_LET_CONST_REF(others_val, kwargs->Get("others")) + << adt::errors::TypeError{ + std::string() + + "Project() need the keyword argument 'others'"}; + ADT_LET_CONST_REF( + others_attrs, + others_val.template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the keyword argument 'others' of Project() should be a " + "BuiltinSerializableAttrMap, but " + + axpr::GetTypeName(others_val) + " were given"}; + others = others_attrs; + } + return GetProjectClass().New( + Project{directory, compile_cmd, so_relative_path, others}); + } +}; + +axpr::TypeImpl> GetProjectClass() { + static auto cls( + axpr::MakeBuiltinClass("Project", [&](const auto& DoEach) { + DoEach("__init__", &TypeProjectClassMethodClass::New); + DoEach("FileContent", GetFileContentClass()); + DoEach("SoftLink", GetSoftLinkClass()); + DoEach("Directory", GetDirectoryClass()); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::code_module diff --git a/paddle/ap/src/code_module/soft_link_method_class.cc b/paddle/ap/src/code_module/soft_link_method_class.cc new file mode 100644 index 00000000000000..a470df17754859 --- /dev/null +++ 
b/paddle/ap/src/code_module/soft_link_method_class.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+
+#include "paddle/ap/include/code_module/soft_link_method_class.h"
+
+namespace ap::code_module {
+
+axpr::TypeImpl> GetSoftLinkClass();
+
+struct TypeSoftLinkMethodClass {
+  // SoftLink(str): wraps a relative path as a symlink target.
+  // NOTE(review): "should a a str" typo fixed in the error message.
+  static adt::Result New(const axpr::Value&,
+                         const std::vector& args) {
+    ADT_CHECK(args.size() == 1) << adt::errors::TypeError{
+        std::string() + "SoftLink() takes 1 argument, but " +
+        std::to_string(args.size()) + " were given."};
+    ADT_LET_CONST_REF(target_relative_path,
+                      args.at(0).template CastTo())
+        << adt::errors::TypeError{
+               std::string() +
+               "the argument 1 of SoftLink() should be a str, but " +
+               axpr::GetTypeName(args.at(0)) + " were given"};
+    return GetSoftLinkClass().New(SoftLink{target_relative_path});
+  }
+};
+
+axpr::TypeImpl> GetSoftLinkClass() {
+  static auto cls(
+      axpr::MakeBuiltinClass("SoftLink", [&](const auto& DoEach) {
+        DoEach("__init__", &TypeSoftLinkMethodClass::New);
+      }));
+  return axpr::MakeGlobalNaiveClassOps(cls);
+}
+
+} // namespace ap::code_module
diff --git a/paddle/ap/src/drr/drr_ctx_method_class.cc b/paddle/ap/src/drr/drr_ctx_method_class.cc
new file mode 100644
index 00000000000000..9c2591fcb3733b
--- /dev/null
+++ b/paddle/ap/src/drr/drr_ctx_method_class.cc
@@ -0,0 +1,233 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/drr_ctx_method_class.h" +#include "paddle/ap/include/axpr/callable_helper.h" + +namespace ap::drr { + +struct DrrCtxMethodClass { + using This = DrrCtxMethodClass; + using Self = drr::DrrCtx; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result StaticInitPassName( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(pass_name, args.at(0).template CastTo()) + << adt::errors::TypeError{ + std::string() + + "DrrCtx.init_pass_name() missing str typed argument 1"}; + self.shared_ptr()->pass_name = pass_name; + return adt::Nothing{}; + } + + static adt::Result StaticSetDrrPassType( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, 
self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(drr_pass_type, args.at(0).template CastTo()) + << adt::errors::TypeError{ + std::string() + + "DrrCtx.set_drr_pass_type() missing str typed argument 1"}; + if (drr_pass_type == "abstract_drr_pass_type") { + self.shared_ptr()->drr_pass_type = drr::AbstractDrrPassType{}; + } else if (drr_pass_type == "reified_drr_pass_type") { + self.shared_ptr()->drr_pass_type = drr::ReifiedDrrPassType{}; + } else if (drr_pass_type == "access_topo_drr_pass_type") { + self.shared_ptr()->drr_pass_type = drr::AccessTopoDrrPassType{}; + } else { + return adt::errors::TypeError{ + std::string() + "invalid drr_pass_type '" + drr_pass_type + + "'. valid drr pass types " + "abstract_drr_pass_type/reified_drr_pass_type/" + "access_topo_drr_pass_type "}; + } + return adt::Nothing{}; + } + + static adt::Result StaticInitSourcePattern( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(!self->source_pattern_ctx.has_value()); + ADT_CHECK(args.size() == 1); + const auto& def_source_pattern = args.at(0); + auto node_arena = std::make_shared>(); + SourcePatternCtx source_pattern_ctx{ + node_arena, + OpPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}, + TensorPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}}; + self.shared_ptr()->source_pattern_ctx = source_pattern_ctx; + DrrValueHelper helper{}; + ADT_RETURN_IF_ERR(interpreter->InterpretCall( + def_source_pattern, + {helper.CastToAxprValue(SrcPtn(source_pattern_ctx->op_pattern_ctx)), + helper.CastToAxprValue( + SrcPtn(source_pattern_ctx->tensor_pattern_ctx))})); + return adt::Nothing{}; + } + + static adt::Result StaticInitConstraintFunc( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + 
ADT_CHECK(axpr::CallableHelper{}.IsCallable(args.at(0))) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of DrrCtx.init_constraint_func() should be a " + "callable object"}; + self.shared_ptr()->constraint_func = args.at(0); + return adt::Nothing{}; + } + + static adt::Result StaticInitResultPattern( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(!self->result_pattern_ctx.has_value()); + ADT_CHECK(args.size() == 1); + const auto& def_result_pattern = args.at(0); + auto node_arena = std::make_shared>(); + ResultPatternCtx result_pattern_ctx{ + node_arena, + OpPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}, + TensorPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}, + self->source_pattern_ctx.value()}; + self.shared_ptr()->result_pattern_ctx = result_pattern_ctx; + DrrValueHelper helper{}; + ADT_RETURN_IF_ERR(interpreter->InterpretCall( + def_result_pattern, + {helper.CastToAxprValue(ResPtn(result_pattern_ctx->op_pattern_ctx)), + helper.CastToAxprValue( + ResPtn(result_pattern_ctx->tensor_pattern_ctx))})); + return adt::Nothing{}; + } +}; + +struct TypeImplDrrCtxMethodClass { + using This = TypeImplDrrCtxMethodClass; + using Self = drr::Type; + + static adt::Result StaticConstruct( + axpr::InterpreterBase* interpreter, + const axpr::Value& instance_val, + const std::vector& args) { + return This{}.Construct(interpreter, instance_val, args); + } + + adt::Result Construct( + axpr::InterpreterBase* interpreter, + const axpr::Value& instance_val, + const std::vector& packed_args_val) { + ADT_LET_CONST_REF( + empty_self, + instance_val + .template CastTo>()); + DrrCtx self{interpreter->circlable_ref_list()}; + if (packed_args_val.size() == 0) { + return empty_self.type.New(self); + } + DrrValueHelper helper{}; + const auto& packed_args = axpr::CastToPackedArgs(packed_args_val); + const auto& [args, kwargs] = *packed_args; + 
ADT_CHECK(args->size() == 0) << adt::errors::TypeError{ + "the constructor of DrrCtx takes keyword arguments only."}; + { + ADT_LET_CONST_REF(def_source_pattern, kwargs->Get("source_pattern")); + auto node_arena = std::make_shared>(); + SourcePatternCtx source_pattern_ctx{ + node_arena, + OpPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}, + TensorPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}}; + self.shared_ptr()->source_pattern_ctx = source_pattern_ctx; + ADT_RETURN_IF_ERR(interpreter->InterpretCall( + def_source_pattern, + {helper.CastToAxprValue(SrcPtn(source_pattern_ctx->op_pattern_ctx)), + helper.CastToAxprValue( + SrcPtn(source_pattern_ctx->tensor_pattern_ctx))})); + } + { + ADT_LET_CONST_REF(def_result_pattern, kwargs->Get("result_pattern")); + auto node_arena = std::make_shared>(); + ResultPatternCtx result_pattern_ctx{ + node_arena, + OpPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}, + TensorPatternCtx{ + node_arena, std::map{}, self.shared_ptr()}, + self->source_pattern_ctx.value()}; + self.shared_ptr()->result_pattern_ctx = result_pattern_ctx; + ADT_RETURN_IF_ERR(interpreter->InterpretCall( + def_result_pattern, + {helper.CastToAxprValue(ResPtn(result_pattern_ctx->op_pattern_ctx)), + helper.CastToAxprValue( + ResPtn(result_pattern_ctx->tensor_pattern_ctx))})); + } + return empty_self.type.New(self); + } +}; + +axpr::TypeImpl> GetDrrCtxClass() { + using Impl = drr::DrrCtxMethodClass; + using TImpl = TypeImplDrrCtxMethodClass; + using TT = drr::Type; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__init__", &TImpl::StaticConstruct); + Define("set_drr_pass_type", &Impl::StaticSetDrrPassType); + Define("init_pass_name", &Impl::StaticInitPassName); + Define("init_source_pattern", &Impl::StaticInitSourcePattern); + Define("init_constraint_func", &Impl::StaticInitConstraintFunc); + Define("init_result_pattern", &Impl::StaticInitResultPattern); + Define("__str__", &Impl::ToString); + 
Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/drr_interpreter.cc b/paddle/ap/src/drr/drr_interpreter.cc new file mode 100644 index 00000000000000..8383d8f6074d8e --- /dev/null +++ b/paddle/ap/src/drr/drr_interpreter.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/drr/drr_interpreter.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/builtin_frame_util.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/drr/value_method_class.h" + +namespace ap::drr { + +namespace adt = ap::adt; + +namespace { + +using Function = ap::axpr::Value; + +using DrrNode = ap::drr::Node; +using DrrCtx = ap::drr::DrrCtx; + +} // namespace + +DrrInterpreter::DrrInterpreter( + const axpr::TypeImpl>& + backend_ir_ctx, + const std::weak_ptr& circlable_ref_list) + : interpreter_(ap::drr::MakeBuiltinFrameAttrMap( + [&](const auto& Insert) { Insert(backend_ir_ctx); }), + circlable_ref_list) {} + +adt::Result DrrInterpreter::InterpretDrrCtxMaker( + const Function& lambda, const std::vector& args) { + ADT_LET_CONST_REF(drr_ctx_val, interpreter_.Interpret(lambda, args)); + ADT_LET_CONST_REF(drr_ctx, drr_ctx_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "drr function should return a 'DrrCtx' object but '" + + ap::axpr::GetTypeName(drr_ctx_val) + "' were given."}; + return drr_ctx; +} + +adt::Result DrrInterpreter::InterpretPass( + const Function& lambda, const std::string& drr_pass_name) { + ADT_LET_CONST_REF(drr_ctx_val, interpreter_.Interpret(lambda, {})); + ADT_LET_CONST_REF(drr_ctx, drr_ctx_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "drr function should return a 'DrrCtx' object but '" + + ap::axpr::GetTypeName(drr_ctx_val) + "' were given."}; + return drr_ctx; +} + +adt::Result DrrInterpreter::InterpretPass( + const ap::axpr::ClassAttrs& cls) { + static ap::axpr::Lambda lambda([] { + ap::axpr::LambdaExprBuilder lmd; + const ap::axpr::AnfExpr anf_expr = 
lmd.Lambda({"cls"}, [](auto& ctx) { + auto& obj = ctx.Var("cls").Call(); + auto& method = obj.Attr("make_drr_ctx"); + auto& ret = method.Call(); + return ret; + }); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + const auto& atomic = core_expr.Get>(); + return atomic.Get>(); + }()); + ap::axpr::Value cls_val{ + ap::axpr::TypeImpl>(cls)}; + ADT_LET_CONST_REF(drr_ctx_val, interpreter_.Interpret(lambda, {cls_val})); + ADT_LET_CONST_REF(drr_ctx, drr_ctx_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "drr function should return a 'DrrCtx' object but '" + + ap::axpr::GetTypeName(drr_ctx_val) + "' were given."}; + return drr_ctx; +} + +ap::adt::Result DrrInterpreter::CreateDrrCtxByDrrPassObj( + const ap::axpr::Value& drr_pass_obj) { + static ap::axpr::Lambda lambda([] { + ap::axpr::LambdaExprBuilder lmd; + const ap::axpr::AnfExpr anf_expr = + lmd.Lambda({"drr_pass_obj"}, [](auto& ctx) { + auto& obj = ctx.Var("drr_pass_obj"); + auto& method = obj.Attr("make_drr_ctx"); + auto& ret = method.Call(); + return ret; + }); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + const auto& atomic = core_expr.Get>(); + return atomic.Get>(); + }()); + ADT_LET_CONST_REF(drr_ctx_val, + interpreter_.Interpret(lambda, {drr_pass_obj})); + ADT_LET_CONST_REF(drr_ctx, drr_ctx_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "drr function should return a 'DrrCtx' object but '" + + ap::axpr::GetTypeName(drr_ctx_val) + "' were given."}; + return drr_ctx; +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/native_ir_op_declare_method_class.cc b/paddle/ap/src/drr/native_ir_op_declare_method_class.cc new file mode 100644 index 00000000000000..3db0b27efa1845 --- /dev/null +++ b/paddle/ap/src/drr/native_ir_op_declare_method_class.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/native_ir_op_declare_method_class.h" + +namespace ap::drr { + +struct SrcPtnNativeIrOpDeclareMethodClassImpl { + using Self = drr::tSrcPtn>; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetSrcPtnNativeIrOpDeclareClass() { + using Impl = SrcPtnNativeIrOpDeclareMethodClassImpl; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +struct ResPtnNativeIrOpDeclareMethodClass { + using Self = drr::tResPtn>; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + 
std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetResPtnNativeIrOpDeclareClass() { + using Impl = ResPtnNativeIrOpDeclareMethodClass; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/native_ir_op_method_class.cc b/paddle/ap/src/drr/native_ir_op_method_class.cc new file mode 100644 index 00000000000000..b24d4839c7104b --- /dev/null +++ b/paddle/ap/src/drr/native_ir_op_method_class.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/native_ir_op_method_class.h" + +namespace ap::drr { + +struct NativeIrOpMethodClass { + using Self = drr::NativeIrOp; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> GetNativeIrOpClass() { + using Impl = NativeIrOpMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/native_ir_value_method_class.cc b/paddle/ap/src/drr/native_ir_value_method_class.cc new file mode 100644 index 00000000000000..215183c8283c1f --- /dev/null +++ b/paddle/ap/src/drr/native_ir_value_method_class.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/native_ir_value_method_class.h" + +namespace ap::drr { + +struct SrcPtnNativeIrValueMethodClassImpl { + using Self = drr::tSrcPtn>; + using This = SrcPtnNativeIrValueMethodClassImpl; + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + if (attr_name == "type") { + ADT_LET_CONST_REF(opt_type, + OpTensorPatternCtxHelper{}.GetOptType(self.value())); + if (opt_type.has_value()) { + return opt_type.value(); + } else { + return adt::Nothing{}; + } + } else { + return adt::errors::AttributeError{ + std::string() + "SrcPtnNativeIrValue '" + self.value()->name + + "' has no attribute '" + attr_name + "'"}; + } + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + const auto& attr_val = args.at(1); + if (attr_name == "type") { + ADT_RETURN_IF_ERR( + drr::OpTensorPatternCtxHelper{}.SetType(self.value(), attr_val)); + return adt::Nothing{}; + } else { + return adt::errors::AttributeError{ + std::string() + "SrcPtnNativeIrValue '" + self.value()->name + + "' has no attribute '" + attr_name + "'"}; + } + } + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + 
} + + static adt::Result Starred( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return adt::errors::TypeError{ + std::string() + + "Only SrcPtnPackedIrValue and ResPtnPackedIrValue tensors can be " + "unpacked. tensor '" + + self.value()->name + "' is of type 'SrcPtnNativeIrValue'"}; + } +}; + +axpr::TypeImpl> +GetSrcPtnNativeIrValueClass() { + using Impl = SrcPtnNativeIrValueMethodClassImpl; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__starred__", &Impl::Starred); + Define("__getattr__", &Impl::GetAttr); + Define("__setattr__", &Impl::SetAttr); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +struct ResPtnNativeIrValueMethodClass { + using Self = drr::tResPtn>; + using This = ResPtnNativeIrValueMethodClass; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result Starred( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return adt::errors::TypeError{ + std::string() + + "Only SrcPtnPackedIrValue and ResPtnPackedIrValue tensors can be " + "unpacked. 
tensor '" + + self.value()->name + "' is of type 'ResPtnNativeIrValue'"}; + } +}; + +axpr::TypeImpl> +GetResPtnNativeIrValueClass() { + using Impl = ResPtnNativeIrValueMethodClass; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__starred__", &Impl::Starred); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc b/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc new file mode 100644 index 00000000000000..364d7542a88ac0 --- /dev/null +++ b/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h" + +namespace ap::drr { + +struct OptPackedIrOpDeclareMethodClass { + using Self = drr::OptPackedIrOpDeclare; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetOptPackedIrOpDeclareClass() { + using Impl = OptPackedIrOpDeclareMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc b/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc new file mode 100644 index 00000000000000..e5b2bb8a76a9cc --- /dev/null +++ b/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/opt_packed_ir_op_method_class.h" + +namespace ap::drr { + +struct OptPackedIrOpMethodClass { + using Self = drr::OptPackedIrOp; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetOptPackedIrOpClass() { + using Impl = OptPackedIrOpMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc b/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc new file mode 100644 index 00000000000000..325fc9550989c7 --- /dev/null +++ b/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/packed_ir_op_declare_method_class.h" + +namespace ap::drr { + +struct SrcPtnPackedIrOpDeclareMethodClassImpl { + using Self = drr::tSrcPtn>; + using This = SrcPtnPackedIrOpDeclareMethodClassImpl; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +struct ResPtnPackedIrOpDeclareMethodClassImpl { + using Self = drr::tResPtn>; + using This = ResPtnPackedIrOpDeclareMethodClassImpl; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetSrcPtnPackedIrOpDeclareClass() { + using Impl = SrcPtnPackedIrOpDeclareMethodClassImpl; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return 
axpr::MakeGlobalNaiveClassOps(cls); +} + +axpr::TypeImpl> +GetResPtnPackedIrOpDeclareClass() { + using Impl = ResPtnPackedIrOpDeclareMethodClassImpl; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/packed_ir_op_method_class.cc b/paddle/ap/src/drr/packed_ir_op_method_class.cc new file mode 100644 index 00000000000000..3dbb966d4f110f --- /dev/null +++ b/paddle/ap/src/drr/packed_ir_op_method_class.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/packed_ir_op_method_class.h" + +namespace ap::drr { + +struct PackedIrOpMethodClass { + using Self = drr::PackedIrOp; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> GetPackedIrOpClass() { + using Impl = PackedIrOpMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/packed_ir_value_method_class.cc b/paddle/ap/src/drr/packed_ir_value_method_class.cc new file mode 100644 index 00000000000000..c61fdec00380ea --- /dev/null +++ b/paddle/ap/src/drr/packed_ir_value_method_class.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/packed_ir_value_method_class.h" + +namespace ap::drr { + +struct SrcPtnPackedIrValueMethodClassImpl { + using Self = drr::tSrcPtn>; + using This = SrcPtnPackedIrValueMethodClassImpl; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result Starred( + const axpr::Value& self_val, const std::vector& args) { + return axpr::Starred{adt::List{self_val}}; + } +}; + +axpr::TypeImpl> +GetSrcPtnPackedIrValueClass() { + using Impl = SrcPtnPackedIrValueMethodClassImpl; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +struct StarredSrcPtnPackedIrValueMethodClassImpl { + using Self = drr::tStarred>>; + using This = StarredSrcPtnPackedIrValueMethodClassImpl; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().value().__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().value().__adt_rc_shared_ptr_raw_ptr(); + return 
reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetStarredSrcPtnPackedIrValueClass() { + using Impl = StarredSrcPtnPackedIrValueMethodClassImpl; + using TT = + drr::Type>>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +struct ResPtnPackedIrValueMethodClassImpl { + using Self = drr::tResPtn>; + using This = ResPtnPackedIrValueMethodClassImpl; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result Starred( + const axpr::Value& self_val, const std::vector& args) { + return axpr::Starred{adt::List{self_val}}; + } +}; + +axpr::TypeImpl> +GetResPtnPackedIrValueClass() { + using Impl = ResPtnPackedIrValueMethodClassImpl; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +struct StarredResPtnPackedIrValueMethodClassImpl { + using This = StarredResPtnPackedIrValueMethodClassImpl; + using Self = drr::tStarred>>; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = 
self.value().value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetStarredResPtnPackedIrValueClass() { + using Impl = StarredResPtnPackedIrValueMethodClassImpl; + using TT = + drr::Type>>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc b/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc new file mode 100644 index 00000000000000..466d89e3c76144 --- /dev/null +++ b/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc @@ -0,0 +1,225 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h" +#include "paddle/ap/include/axpr/callable_helper.h" + +namespace ap::drr { + +struct ResPtnOpPatternCtxMethodClass { + using This = ResPtnOpPatternCtxMethodClass; + using ObjT = tResPtn; + using Self = ObjT; + using Helper = OpTensorPatternCtxHelper; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 2); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(attr_name, arg.template CastTo()); + ADT_CHECK(!This{}.IsBasicAttrName(attr_name)) + << adt::errors::AttributeError{"can't set attribute '" + attr_name + + "'"}; + return MakeAndRegisterUnboundIrOp(self_val, args); + } + + static adt::Result MakeAndRegisterUnboundIrOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(op_uid, args.at(0).template CastTo()); + const auto& drr_value = DrrValueHelper{}.CastFromAxprValue(args.at(1)); + const auto& opt_ir_op = drr_value.DrrValueMatch( + [&](const tResPtn>& op) + -> adt::Result { + return UnboundPackedIrOp{op.value(), op_uid}; + }, + [&](const tResPtn>& op) + -> adt::Result { + return UnboundNativeIrOp{op.value(), op_uid}; + }, + [&](const OptPackedIrOpDeclare&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "only 
'ResPtnPackedIrOpDeclare' and 'ResPtnNativeIrOpDeclare' " + "supported for op name binding. '" + + axpr::GetTypeName(args.at(1)) + "' were given."}; + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "only 'ResPtnPackedIrOpDeclare' and 'ResPtnNativeIrOpDeclare' " + "supported for op name binding. '" + + axpr::GetTypeName(args.at(1)) + "' were given."}; + }); + ADT_LET_CONST_REF(ir_op, opt_ir_op); + bool has_ir_op = Helper{}.HasIrOpByUid(self.value(), op_uid); + if (has_ir_op) { + ADT_RETURN_IF_ERR( + Helper{}.CheckIrOpNameByUid(self.value(), op_uid, ir_op)); + } else { + Helper{}.SetIrOpByUid(self.value(), op_uid, ir_op); + } + return adt::Nothing{}; + } + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + using RetT = adt::Result; + ADT_LET_CONST_REF(attr_name, arg.template CastTo()); + ADT_CHECK(!This{}.IsBasicAttrName(attr_name)) << adt::errors::RuntimeError{ + std::string() + "Dead code encounterred. 
attr_name: " + attr_name}; + ADT_LET_CONST_REF(ir_op, Helper{}.GetIrOpByUid(self.value(), attr_name)); + const auto& convert_result = ir_op.Match( + [](const NativeIrOp& impl) -> adt::Result { + return impl; + }, + [](const PackedIrOp& impl) -> adt::Result { + return impl; + }, + [](const OptPackedIrOp& impl) -> adt::Result { + return adt::errors::KeyError{ + std::string() + + "OptPackedIrOp is not supported in result pattern."}; + }, + [](const UnboundNativeIrOp& x) -> adt::Result { + return ResPtn(x); + }, + [](const UnboundPackedIrOp& x) -> adt::Result { + return ResPtn(x); + }, + [](const UnboundOptPackedIrOp& x) -> adt::Result { + return adt::errors::KeyError{ + std::string() + + "UnboundOptPackedIrOp is not supported in result pattern."}; + }); + ADT_LET_CONST_REF(drr_value, convert_result); + return DrrValueHelper{}.CastToAxprValue(drr_value); + } + + static adt::Result StaticDeclareApPatternFusionOp( + const axpr::Value& self_val, const std::vector& args) { + return This{}.DeclareApPatternFusionOp(self_val, args); + } + + adt::Result DeclareApPatternFusionOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "ResPtnOpPatternCtx.ap_pattern_fusion_op takes 1 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(kernel_define_lambda, CheckCallable(args.at(0))) + << adt::errors::TypeError{std::string() + + "argument 1 of o.ap_pattern_fusion_op should " + "be a function_code object."}; + auto data = + std::make_shared(kernel_define_lambda); + PackedIrOpDeclare op_declare{ + "ap_pattern_fusion_op", self.value().shared_ptr(), data}; + return DrrValueHelper{}.CastToAxprValue(ResPtn(op_declare)); + } + + static adt::Result StaticDeclareApNativeOp( + const axpr::Value& self_val, const std::vector& args) { + return This{}.DeclareApNativeOp(self_val, args); + } + + adt::Result DeclareApNativeOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "ResPtnOpPatternCtx.ap_native_op takes 1 arguments. but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(op_name, args.at(0).template CastTo()) + << adt::errors::TypeError{std::string() + + "argument 1 of o.ap_native_op should " + "be a str."}; + NativeIrOpDeclare op_declare{op_name, self.value().shared_ptr()}; + return DrrValueHelper{}.CastToAxprValue(ResPtn(op_declare)); + } + + adt::Result CheckCallable(const axpr::Value& val) { + ADT_CHECK(axpr::CallableHelper{}.IsCallable(val)) << adt::errors::TypeError{ + std::string() + + "the argument 1 of ResPtnOpPatternCtx.ap_pattern_fusion_op() should be " + "callable"}; + return val; + } + + static adt::Result DeclareNativeIrOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "ResPtnOpPatternCtx.ap_native_op takes 1 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(op_name, args.at(0).template CastTo()); + NativeIrOpDeclare op_declare{op_name, self.value().shared_ptr()}; + return DrrValueHelper{}.CastToAxprValue(ResPtn(op_declare)); + } + + bool IsBasicAttrName(const std::string& attr_name) { + const auto& attr_getters = AttrGetters(); + return attr_getters.count(attr_name) > 0; + } + + const std::set& AttrGetters() { + static const std::set set{ + "ap_pattern_fusion_op", + }; + return set; + } +}; + +axpr::TypeImpl> +GetResPtnOpPatternCtxClass() { + using Impl = drr::ResPtnOpPatternCtxMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getattr__", &Impl::GetAttr); + Define("__setattr__", &Impl::SetAttr); + Define("ap_pattern_fusion_op", &Impl::StaticDeclareApPatternFusionOp); + Define("ap_native_op", &Impl::StaticDeclareApNativeOp); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc b/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc new file mode 100644 index 00000000000000..4eaeb8f9acd07d --- /dev/null +++ b/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h" + +namespace ap::drr { + +struct ResPtnTensorPatternCtx { + using This = ResPtnTensorPatternCtx; + using ObjT = drr::tResPtn; + using Self = ObjT; + using Helper = OpTensorPatternCtxHelper; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(tensor_name, arg.template CastTo()); + const auto& opt_ir_value = + Helper{}.GetIrValueByUid(self.value(), tensor_name); + if (opt_ir_value.HasOkValue()) { + const auto& drr_value = opt_ir_value.GetOkValue().Match( + [](const auto& impl) -> DrrValue { return ResPtn(impl); }); + return DrrValueHelper{}.CastToAxprValue(drr_value); + } + ADT_LET_CONST_REF(drr_ctx_ptr, adt::WeakPtrLock(self.value()->drr_ctx)); + { + ADT_CHECK(drr_ctx_ptr->result_pattern_ctx.has_value()); + const auto& result_pattern_ctx = drr_ctx_ptr->result_pattern_ctx.value(); + const auto& internal_names = + result_pattern_ctx->internal_native_ir_value_names; + if (internal_names.count(tensor_name)) { + UnboundIrValue unbound_ir_value{tensor_name, + self.value().shared_ptr()}; + return DrrValueHelper{}.CastToAxprValue(unbound_ir_value); + 
} + } + const auto& src_tensor_ctx = + drr_ctx_ptr->source_pattern_ctx.value()->tensor_pattern_ctx; + ADT_LET_CONST_REF(src_ir_value, + Helper{}.GetIrValueByUid(src_tensor_ctx, tensor_name)) + << adt::errors::AttributeError{ + std::string() + "no source pattern binding tensor named '" + + tensor_name + "' found."}; + const auto& match_result = src_ir_value.Match( + [&](const NativeIrValue& impl) -> adt::Result { + ADT_LET_CONST_REF( + cloned, Helper{}.CloneIrValueDataAndRegister(self.value(), impl)); + return ResPtn(cloned); + }, + [&](const PackedIrValue& impl) -> adt::Result { + ADT_LET_CONST_REF( + cloned, Helper{}.CloneIrValueDataAndRegister(self.value(), impl)); + return ResPtn(cloned); + }, + [&](const auto&) -> adt::Result { + return adt::errors::AttributeError{ + std::string() + "no source pattern binding tensor named '" + + tensor_name + "' found."}; + }); + ADT_LET_CONST_REF(drr_value, match_result); + return DrrValueHelper{}.CastToAxprValue(drr_value); + } + + static adt::Result DeclareInternalNativeIrValue( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(ir_value_name, arg.template CastTo()); + ADT_LET_CONST_REF(drr_ctx, adt::WeakPtrLock(self.value()->drr_ctx)); + ADT_CHECK(drr_ctx->result_pattern_ctx.has_value()); + auto* result_pattern_ctx = + drr_ctx->result_pattern_ctx.value().shared_ptr().get(); + result_pattern_ctx->internal_native_ir_value_names.insert(ir_value_name); + return adt::Nothing{}; + } +}; + +axpr::TypeImpl> +GetResPtnTensorPatternCtxClass() { + using Impl = drr::ResPtnTensorPatternCtx; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getattr__", &Impl::GetAttr); + Define("declare_internal_native_ir_value", + &Impl::DeclareInternalNativeIrValue); + 
})); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc b/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc new file mode 100644 index 00000000000000..e87015d403dca4 --- /dev/null +++ b/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h" +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/drr/drr_pass_type_helper.h" + +namespace ap::drr { + +struct ResPtnUnboundNativeIrOpMethodClass { + using This = ResPtnUnboundNativeIrOpMethodClass; + using Self = tResPtn>; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + using Helper = OpTensorPatternCtxHelper; + + static adt::Result StaticCall( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.Call(self, args); + } + + adt::Result Call(const Self& self, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + + "ResPtnUnboundNativeIrOp.__call__ takes 2 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(input_vals, + args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the first argument of ResPtnUnboundNativeIrOp.__call__ should " + "be a list."}; + adt::List> inputs; + inputs->reserve(input_vals->size()); + for (const auto& input_val : *input_vals) { + ADT_LET_CONST_REF( + input, input_val.template CastTo>>()) + << adt::errors::TypeError{ + std::string() + + "unsupported operand types for " + "ResPtnUnboundNativeIrOp.__call__ inputs: '" + + axpr::GetTypeName(input_val) + "'."}; + inputs->emplace_back(input.value()); + } + ADT_LET_CONST_REF(output_vals, + args.at(1).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the second argument of ResPtnUnboundNativeIrOp.__call__ should " + "be a list."}; + adt::List> outputs; + outputs->reserve(output_vals->size()); + for (const auto& output_val : *output_vals) { + ADT_LET_CONST_REF(valid_output, + ResPtnValidOutIrValue::CastFromAxprValue(output_val)) + << adt::errors::TypeError{ + std::string() + + "unsupported operand types for " + "ResPtnUnboundNativeIrOp.__call__ outputs: '" + + axpr::GetTypeName(output_val) + "'."}; + using RetT = adt::Result>; + ADT_LET_CONST_REF( + output, + valid_output.Match( + [&](const UnboundIrValue& impl) -> RetT { + return Helper{}.GetNativeIrValueByUnboundIrValue(impl); + }, + [&](const tResPtn>& impl) -> RetT { + return impl.value(); + })); + outputs->emplace_back(output); + } + ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); + ADT_LET_CONST_REF(native_op, + Helper{}.GetNativeIrOpByUnboundNativeIrOp(self.value())); + Helper{}.ConnectIrOpAndIrValue(native_op, inputs, outputs); + return adt::Nothing{}; + } + + adt::Result CheckNoRedundantTensorNames( + const adt::List>& inputs, + const adt::List>& outputs) { + std::unordered_set existed_names; + for (const auto& output : *outputs) { + ADT_CHECK(existed_names.emplace(output->name).second) + << 
adt::errors::TypeError{std::string() + "redundant tensor name '" + + output->name + "' detected."}; + } + return adt::Ok{}; + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + const auto& attr_val = args.at(1); + ADT_LET_CONST_REF(support_reifying, This{}.SupportReifying(self)); + if (support_reifying) { + ADT_RETURN_IF_ERR( + attr_val.template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "an attribute of ResPtnNativeIrOp of abstract_drr_pass_type " + "should be a serializable `Function`(not a " + + axpr::GetTypeName(attr_val) + + "). op_name: " + self.value()->op_declare->op_name + + ", attr_name: " + attr_name}; + } else { + ADT_CHECK(axpr::CallableHelper{}.IsCallable(attr_val)) + << adt::errors::TypeError{std::string() + + "an attribute of ResPtnNativeIrOp should " + "be a callable getter. 
op_name: " + + self.value()->op_declare->op_name + + ", attr_name: " + attr_name}; + } + auto* attr_map = self.value()->op_declare->attr_map.shared_ptr().get(); + attr_map->Set(attr_name, attr_val); + return adt::Nothing{}; + } + + adt::Result SupportReifying(const Self& self) const { + ADT_LET_CONST_REF( + op_pattern_ctx, + adt::WeakPtrLock(self.value()->op_declare->op_pattern_ctx)); + ADT_LET_CONST_REF(drr_ctx, adt::WeakPtrLock(op_pattern_ctx->drr_ctx)); + return DrrPassTypeHelper{}.SupportReifying(drr_ctx->drr_pass_type); + } +}; + +axpr::TypeImpl> +GetResPtnUnboundNativeIrOpClass() { + using TT = drr::Type>>; + using Impl = ResPtnUnboundNativeIrOpMethodClass; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__call__", &Impl::StaticCall); + Define("__setattr__", &Impl::SetAttr); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc b/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc new file mode 100644 index 00000000000000..10954f7702e7ad --- /dev/null +++ b/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h" + +namespace ap::drr { + +struct ResPtnUnboundPackedIrOp { + using This = ResPtnUnboundPackedIrOp; + using Self = tResPtn>; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + using Helper = OpTensorPatternCtxHelper; + + static adt::Result StaticCall( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.Call(self, args); + } + + adt::Result Call(const Self& self, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + + "ResPtnUnboundPackedIrOp.__call__ takes 2 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(input_vals, + args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the first argument of ResPtnUnboundPackedIrOp.__call__ should " + "be a list."}; + adt::List inputs; + inputs->reserve(input_vals->size()); + for (const auto& input_val : *input_vals) { + ADT_LET_CONST_REF(input, CastToIrValue(input_val)); + inputs->emplace_back(input); + } + ADT_LET_CONST_REF(output_vals, + args.at(1).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the second argument of ResPtnUnboundPackedIrOp.__call__ should " + "be a list."}; + adt::List outputs; + outputs->reserve(output_vals->size()); + for (const auto& output_val : *output_vals) { + ADT_LET_CONST_REF(output, CastToIrValue(output_val)); + outputs->emplace_back(output); + } + ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); + ADT_LET_CONST_REF(packed_op, + Helper{}.GetPackedIrOpByUnboundPackedIrOp(self.value())); + Helper{}.ConnectIrOpAndIrValue(packed_op, inputs, outputs); + return adt::Nothing{}; + } + + adt::Result CheckNoRedundantTensorNames( + const adt::List& inputs, const adt::List& outputs) { + std::unordered_set existed_names; + for (const auto& input : *inputs) { + existed_names.insert(input.name()); + } + for (const auto& output : *outputs) { + ADT_CHECK(existed_names.emplace(output.name()).second) + << adt::errors::TypeError{std::string() + "redundant tensor name '" + + output.name() + "' detected."}; + } + return adt::Ok{}; + } + + adt::Result CastToIrValue(const axpr::Value& arg) { + DrrValueHelper helper{}; + return helper.CastFromAxprValue(arg).DrrValueMatch( + [&](const tResPtn>& value) + -> adt::Result { return value.value(); }, + [&](const tResPtn>& value) + -> adt::Result { return value.value(); }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{std::string() + + "unsupported operand types for " + "ResPtnUnboundPackedIrOp.__call__: " + + 
axpr::GetTypeName(arg) + + ". only 'ResPtnNativeIrValue' and " + "'ResPtnPackedIrValue' supported. "}; + }); + } +}; + +axpr::TypeImpl> +GetResPtnUnboundPackedIrOpClass() { + using Impl = drr::ResPtnUnboundPackedIrOp; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__call__", &Impl::StaticCall); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/result_pattern_ctx_method_class.cc b/paddle/ap/src/drr/result_pattern_ctx_method_class.cc new file mode 100644 index 00000000000000..ad2c368fa0102e --- /dev/null +++ b/paddle/ap/src/drr/result_pattern_ctx_method_class.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/result_pattern_ctx_method_class.h" + +namespace ap::drr { + +struct ResultPatternCtxMethodClass { + using Self = drr::ResultPatternCtx; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetResultPatternCtxClass() { + using Impl = ResultPatternCtxMethodClass; + using TT = drr::Type; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/source_pattern_ctx_method_class.cc b/paddle/ap/src/drr/source_pattern_ctx_method_class.cc new file mode 100644 index 00000000000000..010d603a0de84b --- /dev/null +++ b/paddle/ap/src/drr/source_pattern_ctx_method_class.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/source_pattern_ctx_method_class.h" + +namespace ap::drr { + +struct SourcePatternCtxMethodClass { + using Self = drr::SourcePatternCtx; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetSourcePatternCtxClass() { + using Impl = SourcePatternCtxMethodClass; + using TT = drr::Type; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc b/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc new file mode 100644 index 00000000000000..5abca213f461b8 --- /dev/null +++ b/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h" +#include +#include "paddle/ap/include/drr/drr_pass_type_helper.h" + +namespace ap::drr { + +struct SrcPtnOpPatternCtxMethodClass { + using This = SrcPtnOpPatternCtxMethodClass; + using ObjT = tSrcPtn; + using Self = ObjT; + using Helper = OpTensorPatternCtxHelper; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 2); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(attr_name, arg.template CastTo()); + ADT_CHECK(!IsBasicAttrName(attr_name)) << adt::errors::AttributeError{ + "can't set attribute '" + attr_name + "'"}; + return MakeAndRegisterUnboundIrOp(self_val, args); + } + + static adt::Result MakeAndRegisterUnboundIrOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(op_uid, args.at(0).template CastTo()); + const auto& 
drr_value = DrrValueHelper{}.CastFromAxprValue(args.at(1)); + const auto& opt_ir_op = drr_value.DrrValueMatch( + [&](const tSrcPtn>& op) + -> adt::Result { + return UnboundPackedIrOp{op.value(), op_uid}; + }, + [&](const OptPackedIrOpDeclare& op) -> adt::Result { + return UnboundOptPackedIrOp{op, op_uid}; + }, + [&](const tSrcPtn>& op) + -> adt::Result { + return UnboundNativeIrOp{op.value(), op_uid}; + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "only 'SrcPtnPackedIrOpDeclare' and 'SrcPtnNativeIrOpDeclare' " + "supported for op name binding. '" + + axpr::GetTypeName(args.at(1)) + "' were given."}; + }); + ADT_LET_CONST_REF(ir_op, opt_ir_op); + bool has_ir_op = Helper{}.HasIrOpByUid(self.value(), op_uid); + if (has_ir_op) { + ADT_RETURN_IF_ERR( + Helper{}.CheckIrOpNameByUid(self.value(), op_uid, ir_op)); + } else { + Helper{}.SetIrOpByUid(self.value(), op_uid, ir_op); + } + return adt::Nothing{}; + } + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(attr_name, arg.template CastTo()); + ADT_CHECK(!IsBasicAttrName(attr_name)) << adt::errors::RuntimeError{ + std::string() + "Dead code encounterred. 
attr_name: " + attr_name}; + ADT_LET_CONST_REF(ir_op, Helper{}.GetIrOpByUid(self.value(), attr_name)); + const auto& drr_value = ir_op.Match( + [](const NativeIrOp& impl) -> DrrValue { return impl; }, + [](const PackedIrOp& impl) -> DrrValue { return impl; }, + [](const OptPackedIrOp& impl) -> DrrValue { return impl; }, + [](const UnboundOptPackedIrOp& impl) -> DrrValue { + return impl; + }, + [](const UnboundNativeIrOp& x) -> DrrValue { + return SrcPtn(x); + }, + [](const UnboundPackedIrOp& x) -> DrrValue { + return SrcPtn(x); + }); + return DrrValueHelper{}.CastToAxprValue(drr_value); + } + + static adt::Result DeclareApTrivialFusionOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::optional> opt_func; + if (args.size() == 1) { + ADT_LET_CONST_REF( + func, + args.at(0) + .template CastTo>()); + opt_func = func; + } else { + ADT_CHECK(args.size() == 0) << adt::errors::TypeError{ + std::string() + + "SrcPtnOpPatternCtx.ap_trivial_fusion_op takes 1 or 0 arguments. " + "but " + + std::to_string(args.size()) + " were given."}; + } + auto ptr = std::make_shared(); + ptr->inner_source_pattern_func = opt_func; + std::shared_ptr op_declare_data{ptr}; + PackedIrOpDeclare op_declare{ + "ap_trivial_fusion_op", self.value().shared_ptr(), op_declare_data}; + return DrrValueHelper{}.CastToAxprValue(SrcPtn(op_declare)); + } + + static adt::Result DeclareOptionalApTrivialFusionOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(drr_ctx, adt::WeakPtrLock(self.value()->drr_ctx)); + ADT_CHECK( + DrrPassTypeHelper{}.SupportOptionalPackedOp(drr_ctx->drr_pass_type)); + ADT_CHECK(args.size() == 0) + << adt::errors::TypeError{std::string() + + "SrcPtnOpPatternCtx.optional_ap_trivial_" + "fusion_op takes 0 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + OptPackedIrOpDeclare op_declare{ + "ap_trivial_fusion_op", self.value().shared_ptr(), std::nullopt}; + return DrrValueHelper{}.CastToAxprValue(op_declare); + } + + static adt::Result DeclareNativeIrOp( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + + "SrcPtnOpPatternCtx.ap_native_op takes 1 arguments. but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(op_name, args.at(0).template CastTo()); + NativeIrOpDeclare op_declare{op_name, self.value().shared_ptr()}; + return DrrValueHelper{}.CastToAxprValue(SrcPtn(op_declare)); + } + + static bool IsBasicAttrName(const std::string& attr_name) { + const auto& attr_getters = AttrGetters(); + return attr_getters.count(attr_name) > 0; + } + + static const std::set& AttrGetters() { + static const std::set set{ + "ap_trivial_fusion_op", + "optional_ap_trivial_fusion_op", + "ap_native_op", + }; + return set; + } +}; + +axpr::TypeImpl> +GetSrcPtnOpPatternCtxClass() { + using Impl = drr::SrcPtnOpPatternCtxMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("ap_trivial_fusion_op", &Impl::DeclareApTrivialFusionOp); + Define("optional_ap_trivial_fusion_op", + &Impl::DeclareOptionalApTrivialFusionOp); + Define("ap_native_op", &Impl::DeclareNativeIrOp); + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getattr__", &Impl::GetAttr); + Define("__setattr__", &Impl::SetAttr); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc b/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc new file mode 100644 index 00000000000000..3837c9eb37942e --- /dev/null +++ 
b/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h" + +namespace ap::drr { + +struct SrcPtnTensorPatternCtx { + using This = SrcPtnTensorPatternCtx; + using ObjT = drr::tSrcPtn; + using Self = ObjT; + using Helper = OpTensorPatternCtxHelper; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + return This::GetOrCreateTensor(self_val, args); + } + + static adt::Result GetOrCreateTensor( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + const auto& arg = args.at(0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + if (arg.template Has()) { + return adt::Nothing{}; + } + 
ADT_LET_CONST_REF(tensor_name, arg.template CastTo()); + + const auto& opt_ir_value = + Helper{}.GetIrValueByUid(self.value(), tensor_name); + if (opt_ir_value.HasError()) { + UnboundIrValue unbound_ir_value{tensor_name, + self.value().shared_ptr()}; + return DrrValueHelper{}.CastToAxprValue(unbound_ir_value); + } + const auto& ir_value = opt_ir_value.GetOkValue(); + const auto& drr_value = ir_value.Match( + [](const auto& impl) -> DrrValue { return SrcPtn(impl); }); + return DrrValueHelper{}.CastToAxprValue(drr_value); + } +}; + +axpr::TypeImpl> +GetSrcPtnTensorPatternCtxClass() { + using Impl = drr::SrcPtnTensorPatternCtx; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__getattr__", &Impl::GetAttr); + Define("get_or_create_tensor", &Impl::GetOrCreateTensor); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc b/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc new file mode 100644 index 00000000000000..87c62331f10601 --- /dev/null +++ b/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h" + +namespace ap::drr { + +using SrcPtnValidIrValueImpl = + std::variant, UnboundIrValue>; + +struct SrcPtnValidIrValue : public SrcPtnValidIrValueImpl { + using SrcPtnValidIrValueImpl::SrcPtnValidIrValueImpl; + + ADT_DEFINE_VARIANT_METHODS(SrcPtnValidIrValueImpl); + + const std::string& name() const { + return Match([](const auto& ir_value) -> const std::string& { + return ir_value->name; + }); + } +}; + +struct SrcPtnUnboundNativeIrOp { + using This = SrcPtnUnboundNativeIrOp; + using Self = tSrcPtn>; + + static adt::Result GetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + const auto& op_declare = self.value()->op_declare; + const auto& attr_map = op_declare->attr_map; + ADT_CHECK(attr_map->Has(attr_name)) << adt::errors::AttributeError{ + std::string() + "SrcPtnUnboundNativeIrOp '" + op_declare->op_name + + "' has no attribute '" + attr_name + "'"}; + return attr_map->Get(attr_name); + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + const auto& attr_val = args.at(1); + auto* attr_map = self.value()->op_declare->attr_map.shared_ptr().get(); + attr_map->Set(attr_name, attr_val); + return adt::Nothing{}; + } + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + 
ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + using Helper = OpTensorPatternCtxHelper; + + static adt::Result StaticCall( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.Call(self, args); + } + + adt::Result Call(const Self& self, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + + "SrcPtnUnboundNativeIrOp.__call__ takes 2 arguments. but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(input_vals, + args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the first argument of SrcPtnUnboundNativeIrOp.__call__ should " + "be a list."}; + adt::List inputs; + inputs->reserve(input_vals->size()); + for (const auto& input_val : *input_vals) { + ADT_LET_CONST_REF(input, CastToSrcPtnValidIrValue(input_val)); + inputs->emplace_back(input); + } + ADT_LET_CONST_REF(output_vals, + args.at(1).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the second argument of SrcPtnUnboundNativeIrOp.__call__ should " + "be a list."}; + adt::List> outputs; + outputs->reserve(output_vals->size()); + for (const auto& output_val : *output_vals) { + ADT_LET_CONST_REF( + output, output_val.template CastTo>()); + outputs->emplace_back(output); + } + ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); + ADT_LET_CONST_REF(native_inputs, ConvertInputs(inputs)); + ADT_LET_CONST_REF(native_outputs, ConvertOutputs(outputs)); + ADT_LET_CONST_REF(native_op, + Helper{}.GetNativeIrOpByUnboundNativeIrOp(self.value())); + Helper{}.ConnectIrOpAndIrValue(native_op, native_inputs, native_outputs); + return adt::Nothing{}; + } + + adt::Result>> ConvertInputs( + const adt::List& inputs) { + adt::List> ret_inputs; + ret_inputs->reserve(inputs->size()); + using Native = NativeIrValue; + for 
(const auto& input : *inputs) { + const auto& opt_ret_input = input.Match( + [&](const NativeIrValue& ir_value) -> adt::Result { + return ir_value; + }, + [&](const UnboundIrValue& ir_value) + -> adt::Result { + return Helper{}.GetNativeIrValueByUnboundIrValue(ir_value); + }); + ADT_LET_CONST_REF(ret_input, opt_ret_input); + ret_inputs->emplace_back(ret_input); + } + return ret_inputs; + } + + adt::Result>> ConvertOutputs( + const adt::List>& outputs) { + adt::List> ret_outputs; + ret_outputs->reserve(outputs->size()); + for (const auto& output : *outputs) { + ADT_LET_CONST_REF(ret_output, + Helper{}.GetNativeIrValueByUnboundIrValue(output)); + ret_outputs->emplace_back(ret_output); + } + return ret_outputs; + } + + adt::Result CheckNoRedundantTensorNames( + const adt::List& inputs, + const adt::List>& outputs) { + std::unordered_set existed_names; + for (const auto& input : *inputs) { + existed_names.insert(input.name()); + } + for (const auto& output : *outputs) { + ADT_CHECK(existed_names.emplace(output->name).second) + << adt::errors::TypeError{std::string() + "redundant tensor name '" + + output->name + "' detected."}; + } + return adt::Ok{}; + } + + adt::Result CastToSrcPtnValidIrValue( + const axpr::Value& arg) { + DrrValueHelper helper{}; + return helper.CastFromAxprValue(arg).DrrValueMatch( + [&](const tSrcPtn>& value) + -> adt::Result { return value.value(); }, + [&](const UnboundIrValue& value) + -> adt::Result { return value; }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "unsupported operand types for " + "SrcPtnUnboundNativeIrOp.__call__: " + + axpr::GetTypeName(arg) + + ". only 'SrcPtnNativeIrValue' and 'UnboundIrValue' supported. 
"}; + }); + } +}; + +axpr::TypeImpl> +GetSrcPtnUnboundNativeIrOpClass() { + using Impl = drr::SrcPtnUnboundNativeIrOp; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__call__", &Impl::StaticCall); + Define("__getattr__", &Impl::GetAttr); + Define("__setattr__", &Impl::SetAttr); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc b/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc new file mode 100644 index 00000000000000..a9382aa62b6c7c --- /dev/null +++ b/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc @@ -0,0 +1,307 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h" + +namespace ap::drr { + +struct SrcPtnUnboundPackedIrOp { + using This = SrcPtnUnboundPackedIrOp; + using Self = tSrcPtn>; + + using DrrNode = drr::Node; + using DrrNativeIrValue = drr::NativeIrValue; + using DrrPackedIrValue = drr::PackedIrValue; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.value().__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + if (attr_name == "inner_source_pattern_func") { + ADT_LET_CONST_REF( + func, + args.at(0) + .template CastTo>()); + auto* raw_ptr = self.value()->op_declare->data.value().get(); + auto* ptr = dynamic_cast(raw_ptr); + ADT_CHECK(ptr != nullptr); + ptr->inner_source_pattern_func = func; + } else { + return adt::errors::AttributeError{ + std::string() + "SrcPtnUnboundPackedIrOp object has no attribute '" + + attr_name + "'"}; + } + return adt::Nothing{}; + } + + using Helper = OpTensorPatternCtxHelper; + + static adt::Result StaticCall( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.Call(interpreter, self, args); + } + + adt::Result InitInnerSourcePatternCtx( + axpr::InterpreterBase* interpreter, const Self& self) { + 
const auto& op_declare = self.value()->op_declare; + ADT_LET_CONST_REF(op_pattern_ctx, + adt::WeakPtrLock(op_declare->op_pattern_ctx)); + const auto& drr_ctx_impl = op_pattern_ctx->drr_ctx; + auto node_arena = std::make_shared>(); + SourcePatternCtx inner_source_pattern_ctx{ + node_arena, + OpPatternCtx{node_arena, std::map{}, drr_ctx_impl}, + TensorPatternCtx{ + node_arena, std::map{}, drr_ctx_impl}}; + ADT_CHECK(op_declare->data.has_value()); + auto* raw_ptr = op_declare->data.value().get(); + auto* ptr = dynamic_cast(raw_ptr); + ADT_CHECK(ptr != nullptr); + if (!ptr->inner_source_pattern_func.has_value()) { + return adt::Nothing{}; + } + ADT_CHECK(!ptr->inner_source_pattern_ctx.has_value()); + ptr->inner_source_pattern_ctx = inner_source_pattern_ctx; + ADT_CHECK(ptr->inner_source_pattern_func.has_value()); + const auto& inner_source_pattern_func = + ptr->inner_source_pattern_func.value(); + DrrValueHelper helper{}; + ADT_RETURN_IF_ERR(interpreter->InterpretCall( + inner_source_pattern_func, + {helper.CastToAxprValue( + SrcPtn(inner_source_pattern_ctx->op_pattern_ctx)), + helper.CastToAxprValue( + SrcPtn(inner_source_pattern_ctx->tensor_pattern_ctx))})); + return adt::Nothing{}; + } + + adt::Result Call(axpr::InterpreterBase* interpreter, + const Self& self, + const std::vector& args) { + ADT_RETURN_IF_ERR(InitInnerSourcePatternCtx(interpreter, self)); + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + + "SrcPtnUnboundPackedIrOp.__call__ takes 2 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(input_vals, + args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the first argument of SrcPtnUnboundPackedIrOp.__call__ should " + "be a list."}; + adt::List inputs; + inputs->reserve(input_vals->size()); + for (const auto& input_val : *input_vals) { + ADT_LET_CONST_REF(input, CastToSrcPtnValidInIrValue(input_val)); + inputs->emplace_back(input); + } + ADT_LET_CONST_REF(output_vals, + args.at(1).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the second argument of SrcPtnUnboundPackedIrOp.__call__ should " + "be a list."}; + adt::List outputs; + outputs->reserve(output_vals->size()); + for (const auto& output_val : *output_vals) { + ADT_LET_CONST_REF(output, CastToSrcPtnValidOutIrValue(output_val)); + outputs->emplace_back(output); + } + ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); + ADT_LET_CONST_REF(packed_inputs, ConvertInputs(inputs)); + { + ADT_LET_CONST_REF(num_packed_ir_value_inputs, + GetNumPackedIrValues(packed_inputs)); + ADT_CHECK(num_packed_ir_value_inputs <= 1) << adt::errors::TypeError{ + std::string() + + "SrcPtnUnboundPackedIrOp.__call__(): only 0 or 1 packed ir value " + "inputs are supported. " + + std::to_string(num_packed_ir_value_inputs) + " inputs were given."}; + } + ADT_LET_CONST_REF(packed_outputs, ConvertOutputs(outputs)); + { + ADT_LET_CONST_REF(num_packed_ir_value_outputs, + GetNumPackedIrValues(packed_outputs)); + ADT_CHECK(num_packed_ir_value_outputs <= 1) << adt::errors::TypeError{ + std::string() + + "SrcPtnUnboundPackedIrOp.__call__(): only 0 or 1 packed ir value " + "outputs are supported. 
" + + std::to_string(num_packed_ir_value_outputs) + " outputs were given."}; + } + ADT_LET_CONST_REF(packed_op, + Helper{}.GetPackedIrOpByUnboundPackedIrOp(self.value())); + ADT_RETURN_IF_ERR(Helper{}.ConnectIrOpAndIrValue( + packed_op, packed_inputs, packed_outputs)); + return adt::Nothing{}; + } + + adt::Result GetNumPackedIrValues( + const adt::List& ir_values) const { + std::size_t count = 0; + for (const auto& ir_value : *ir_values) { + count += ir_value.template Has(); + } + return count; + } + + adt::Result> ConvertInputs( + const adt::List& inputs) { + adt::List ret_inputs; + ret_inputs->reserve(inputs->size()); + using IrVal = IrValue; + for (const auto& input : *inputs) { + const auto& opt_ret_input = input.Match( + [&](const NativeIrValue& ir_value) -> adt::Result { + return ir_value; + }, + [&](const PackedIrValue& ir_value) -> adt::Result { + return ir_value; + }, + [&](const UnboundIrValue& ir_value) -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetNativeIrValueByUnboundIrValue(ir_value)); + return ret; + }, + [&](const UnboundPackedIrValue& ir_value) + -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetPackedIrValueByUnboundPackedIrValue(ir_value)); + return ret; + }); + ADT_LET_CONST_REF(ret_input, opt_ret_input); + ret_inputs->emplace_back(ret_input); + } + return ret_inputs; + } + + adt::Result> ConvertOutputs( + const adt::List& outputs) { + adt::List ret_outputs; + using IrVal = IrValue; + ret_outputs->reserve(outputs->size()); + for (const auto& output : *outputs) { + const auto& opt_ret_output = output.Match( + [&](const UnboundIrValue& ir_value) -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetNativeIrValueByUnboundIrValue(ir_value)); + return ret; + }, + [&](const UnboundPackedIrValue& ir_value) + -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetPackedIrValueByUnboundPackedIrValue(ir_value)); + return ret; + }); + ADT_LET_CONST_REF(ret_output, opt_ret_output); + ret_outputs->emplace_back(ret_output); + } + 
return ret_outputs; + } + + adt::Result CheckNoRedundantTensorNames( + const adt::List& inputs, + const adt::List& outputs) { + std::unordered_set existed_names; + for (const auto& input : *inputs) { + existed_names.insert(input.name()); + } + for (const auto& output : *outputs) { + ADT_CHECK(existed_names.emplace(output.name()).second) + << adt::errors::TypeError{std::string() + "redundant tensor name '" + + output.name() + "' detected."}; + } + return adt::Ok{}; + } + + adt::Result CastToSrcPtnValidInIrValue( + const axpr::Value& arg) { + return DrrValueHelper{}.CastFromAxprValue(arg).DrrValueMatch( + [&](const tSrcPtn>& value) + -> adt::Result { + return SrcPtnValidInIrValue{value.value()}; + }, + [&](const tSrcPtn>& value) + -> adt::Result { + return SrcPtnValidInIrValue{value.value()}; + }, + [&](const UnboundIrValue& value) + -> adt::Result { return value; }, + [&](const UnboundPackedIrValue& value) + -> adt::Result { return value; }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "unsupported operand types for the first arguments of " + "SrcPtnUnboundPackedIrOp.__call__: " + + axpr::GetTypeName(arg) + + ". only 'SrcPtnPackedIrValue' and 'UnboundIrValue' supported. "}; + }); + } + + adt::Result CastToSrcPtnValidOutIrValue( + const axpr::Value& arg) { + return DrrValueHelper{}.CastFromAxprValue(arg).DrrValueMatch( + [&](const UnboundIrValue& value) + -> adt::Result { return value; }, + [&](const UnboundPackedIrValue& value) + -> adt::Result { return value; }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "unsupported operand types for the second arguments of " + "SrcPtnUnboundPackedIrOp.__call__: " + + axpr::GetTypeName(arg) + + ". only 'SrcPtnPackedIrValue' and 'UnboundIrValue' supported. 
"}; + }); + } +}; + +axpr::TypeImpl> +GetSrcPtnUnboundPackedIrOpClass() { + using Impl = drr::SrcPtnUnboundPackedIrOp; + using TT = drr::Type>>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + Define("__call__", &Impl::StaticCall); + Define("__setattr__", &Impl::SetAttr); + })); + using Self = Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/unbound_ir_value_method_class.cc b/paddle/ap/src/drr/unbound_ir_value_method_class.cc new file mode 100644 index 00000000000000..a439916ed2bd7b --- /dev/null +++ b/paddle/ap/src/drr/unbound_ir_value_method_class.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/unbound_ir_value_method_class.h" + +namespace ap::drr { + +struct UnboundIrValueMethodClassImpl { + using This = UnboundIrValueMethodClassImpl; + using Self = UnboundIrValue; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + static adt::Result SetAttr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(attr_name, args.at(0).template CastTo()); + const auto& attr_val = args.at(1); + if (attr_name == "type") { + ADT_RETURN_IF_ERR(OpTensorPatternCtxHelper{}.SetType(self, attr_val)); + return adt::Nothing{}; + } else { + return adt::errors::AttributeError{ + std::string(axpr::GetTypeName(self_val)) + " '" + self->name + + "' has no attribute '" + attr_name + "'"}; + } + } + + static adt::Result Starred( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + UnboundPackedIrValue packed_ir_value{self->name, + self->tensor_pattern_ctx}; + DrrValueHelper helper{}; + axpr::Value starred{helper.CastToAxprValue(packed_ir_value)}; + return axpr::Starred{adt::List{starred}}; + } +}; + +axpr::TypeImpl> +GetUnboundIrValueClass() { + using Impl = UnboundIrValueMethodClassImpl; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + 
Define("__starred__", &Impl::Starred); + Define("__setattr__", &Impl::SetAttr); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc b/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc new file mode 100644 index 00000000000000..4bde3c896544d0 --- /dev/null +++ b/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc @@ -0,0 +1,270 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h" + +namespace ap::drr { + +struct UnboundOptPackedIrOpMethodClass { + using This = UnboundOptPackedIrOpMethodClass; + using Self = UnboundOptPackedIrOp; + + using DrrNode = drr::Node; + using DrrNativeIrValue = drr::NativeIrValue; + using DrrPackedIrValue = drr::PackedIrValue; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } + + using Helper = OpTensorPatternCtxHelper; + + static adt::Result StaticCall( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return This{}.Call(self, args); + } + + adt::Result Call(const Self& self, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + + "UnboundOptPackedIrOp.__call__ takes 2 arguments. 
but " + + std::to_string(args.size()) + " were given."}; + ADT_LET_CONST_REF(input_vals, + args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the first argument of UnboundOptPackedIrOp.__call__ should " + "be a list."}; + adt::List inputs; + inputs->reserve(input_vals->size()); + for (const auto& input_val : *input_vals) { + ADT_LET_CONST_REF(input, CastToSrcPtnValidInIrValue(input_val)); + inputs->emplace_back(input); + } + ADT_LET_CONST_REF(output_vals, + args.at(1).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the second argument of UnboundOptPackedIrOp.__call__ should " + "be a list."}; + adt::List outputs; + outputs->reserve(output_vals->size()); + for (const auto& output_val : *output_vals) { + ADT_LET_CONST_REF(output, CastToSrcPtnValidOutIrValue(output_val)); + outputs->emplace_back(output); + } + ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); + ADT_LET_CONST_REF(packed_inputs, ConvertInputs(inputs)); + { + ADT_LET_CONST_REF(num_packed_ir_value_inputs, + GetNumIrValues(packed_inputs)); + ADT_CHECK(num_packed_ir_value_inputs <= 1) << adt::errors::TypeError{ + std::string() + + "UnboundOptPackedIrOp.__call__(): only 0 or 1 packed ir value inputs " + "are supported. 
" + + std::to_string(num_packed_ir_value_inputs) + " inputs were given."}; + } + { + ADT_LET_CONST_REF(num_native_ir_value_inputs, + GetNumIrValues(packed_inputs)); + ADT_CHECK(num_native_ir_value_inputs == 1) << adt::errors::TypeError{ + std::string() + + "UnboundOptPackedIrOp.__call__(): only sole native ir value input " + "are supported, but" + + std::to_string(num_native_ir_value_inputs) + + " native ir value inputs were given."}; + } + ADT_LET_CONST_REF(packed_outputs, ConvertOutputs(outputs)); + { + ADT_LET_CONST_REF( + num_packed_ir_value_outputs, + this->template GetNumIrValues(packed_outputs)); + ADT_CHECK(num_packed_ir_value_outputs <= 1) << adt::errors::TypeError{ + std::string() + + "UnboundOptPackedIrOp.__call__(): only 0 or 1 packed ir value " + "outputs are supported. " + + std::to_string(num_packed_ir_value_outputs) + " outputs were given."}; + } + { + ADT_LET_CONST_REF( + num_native_ir_value_outputs, + this->template GetNumIrValues(packed_outputs)); + ADT_CHECK(num_native_ir_value_outputs == 1) << adt::errors::TypeError{ + std::string() + + "UnboundOptPackedIrOp.__call__(): only sole native ir value output " + "are supported, but" + + std::to_string(num_native_ir_value_outputs) + + " native ir value outputs were given."}; + } + ADT_LET_CONST_REF(opt_packed_op, + Helper{}.GetOptPackedIrOpByUnboundOptPackedIrOp(self)); + ADT_RETURN_IF_ERR(Helper{}.ConnectIrOpAndIrValue( + opt_packed_op, packed_inputs, packed_outputs)); + return adt::Nothing{}; + } + + template + adt::Result GetNumIrValues( + const adt::List& ir_values) const { + std::size_t count = 0; + for (const auto& ir_value : *ir_values) { + count += ir_value.template Has(); + } + return count; + } + + adt::Result> ConvertInputs( + const adt::List& inputs) { + adt::List ret_inputs; + ret_inputs->reserve(inputs->size()); + using IrVal = IrValue; + for (const auto& input : *inputs) { + const auto& opt_ret_input = input.Match( + [&](const NativeIrValue& ir_value) -> adt::Result { + return ir_value; 
+ }, + [&](const PackedIrValue& ir_value) -> adt::Result { + return ir_value; + }, + [&](const UnboundIrValue& ir_value) -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetNativeIrValueByUnboundIrValue(ir_value)); + return ret; + }, + [&](const UnboundPackedIrValue& ir_value) + -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetPackedIrValueByUnboundPackedIrValue(ir_value)); + return ret; + }); + ADT_LET_CONST_REF(ret_input, opt_ret_input); + ret_inputs->emplace_back(ret_input); + } + return ret_inputs; + } + + adt::Result> ConvertOutputs( + const adt::List& outputs) { + adt::List ret_outputs; + using IrVal = IrValue; + ret_outputs->reserve(outputs->size()); + for (const auto& output : *outputs) { + const auto& opt_ret_output = output.Match( + [&](const UnboundIrValue& ir_value) -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetNativeIrValueByUnboundIrValue(ir_value)); + return ret; + }, + [&](const UnboundPackedIrValue& ir_value) + -> adt::Result { + ADT_LET_CONST_REF( + ret, Helper{}.GetPackedIrValueByUnboundPackedIrValue(ir_value)); + return ret; + }); + ADT_LET_CONST_REF(ret_output, opt_ret_output); + ret_outputs->emplace_back(ret_output); + } + return ret_outputs; + } + + adt::Result CheckNoRedundantTensorNames( + const adt::List& inputs, + const adt::List& outputs) { + std::unordered_set existed_names; + for (const auto& input : *inputs) { + existed_names.insert(input.name()); + } + for (const auto& output : *outputs) { + ADT_CHECK(existed_names.emplace(output.name()).second) + << adt::errors::TypeError{std::string() + "redundant tensor name '" + + output.name() + "' detected."}; + } + return adt::Ok{}; + } + + adt::Result CastToSrcPtnValidInIrValue( + const axpr::Value& arg) { + DrrValueHelper helper{}; + return helper.CastFromAxprValue(arg).DrrValueMatch( + [&](const tSrcPtn>& value) + -> adt::Result { + return SrcPtnValidInIrValue{value.value()}; + }, + [&](const tSrcPtn>& value) + -> adt::Result { + return 
SrcPtnValidInIrValue{value.value()}; + }, + [&](const UnboundIrValue& value) + -> adt::Result { return value; }, + [&](const UnboundPackedIrValue& value) + -> adt::Result { return value; }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "unsupported operand types for the first arguments of " + "UnboundOptPackedIrOp.__call__: " + + axpr::GetTypeName(arg) + + ". only 'SrcPtnPackedIrValue' and 'UnboundIrValue' supported. "}; + }); + } + + adt::Result CastToSrcPtnValidOutIrValue( + const axpr::Value& arg) { + DrrValueHelper helper{}; + return helper.CastFromAxprValue(arg).DrrValueMatch( + [&](const UnboundIrValue& value) + -> adt::Result { return value; }, + [&](const UnboundPackedIrValue& value) + -> adt::Result { return value; }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + std::string() + + "unsupported operand types for the second arguments of " + "UnboundOptPackedIrOp.__call__: " + + axpr::GetTypeName(arg) + + ". only 'SrcPtnPackedIrValue' and 'UnboundIrValue' supported. "}; + }); + } +}; + +axpr::TypeImpl> +GetUnboundOptPackedIrOpClass() { + using Impl = UnboundOptPackedIrOpMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc b/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc new file mode 100644 index 00000000000000..288f51586b423e --- /dev/null +++ b/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/drr/unbound_packed_ir_value_method_class.h" + +namespace ap::drr { + +struct UnboundPackedIrValueMethodClass { + using Self = drr::UnboundPackedIrValue; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << "<" << drr::Type{}.Name() << " object at " << ptr << ">"; + return ss.str(); + } + + static adt::Result Hash(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + return reinterpret_cast(ptr); + } +}; + +axpr::TypeImpl> +GetUnboundPackedIrValueClass() { + using Impl = UnboundPackedIrValueMethodClass; + using TT = drr::Type>; + static auto cls( + axpr::MakeBuiltinClass(TT{}.Name(), [&](const auto& Define) { + Define("__str__", &Impl::ToString); + Define("__hash__", &Impl::Hash); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::drr diff --git a/paddle/ap/src/index_expr/index_closure.cc b/paddle/ap/src/index_expr/index_closure.cc new file mode 100644 index 00000000000000..84b2aba34b6aa5 --- /dev/null +++ b/paddle/ap/src/index_expr/index_closure.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/index_expr/index_closure.h" +#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" +#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" + +namespace ap::index_expr { + +adt::Result OrderedOneofIndexClosureImpl::operator()( + const IndexTupleExpr& indexes_expr) const { + size_t count = 0; + for (const auto& [_, lambdas] : nice2index_lambdas) { + for (const auto& lambda : lambdas) { + const auto& res = CallLambda(lambda, indexes_expr); + if (res.Has()) { + return res.Get(); + } + ++count; + } + } + return adt::errors::ValueError{ + std::string() + + "all index closure failed. 
tried count: " + std::to_string(count)}; +} + +adt::Result OrderedOneofIndexClosureImpl::CallLambda( + const Lambda& lambda, const IndexTupleExpr& indexes_expr) const { + axpr::BuiltinClassInstance instance{GetIndexTupleExprClass(), + indexes_expr}; + const std::vector args{closure_data.ctx, + closure_data.inputs_meta, + closure_data.outputs_meta, + closure_data.in_vars, + Val{instance}}; + const auto& opt_ret = (*this->interpreter)(lambda, args); + ADT_RETURN_IF_ERR(opt_ret); + const auto& ret = opt_ret.GetOkValue(); + return ret.template CastTo(); +} + +namespace { + +template +adt::Result OpIndexesTransformApply( + const OpIndexesTransformSignature& indexes_transform_signature, + const IndexesTransformApplyT& IndexesTransformApply) { + InIndexTupleExprSignature in_sig; + for (const auto& transform : + *indexes_transform_signature.in_signature.descriptors) { + const auto& converted = IndexesTransformApply(transform); + ADT_RETURN_IF_ERR(converted); + in_sig.descriptors->emplace_back(converted.GetOkValue()); + } + OutIndexTupleExprSignature out_sig; + for (const auto& transform : + *indexes_transform_signature.out_signature.descriptors) { + const auto& converted = IndexesTransformApply(transform); + ADT_RETURN_IF_ERR(converted); + out_sig.descriptors->emplace_back(converted.GetOkValue()); + } + return OpIndexTupleExprSignature{in_sig, out_sig}; +} + +} // namespace + +adt::Result RecordableIndexClosureImpl::operator()( + const IndexTupleExpr& indexes_expr) const { + const auto& ApplyTransform = [&](const TrackedIndexesTransform& transform) { + return transform.Match( + [&](const adt::IdentityFunc&) -> adt::Result { + return indexes_expr; + }, + [&](const IndexTupleExpr& tacked_indexes_expr_as_func) + -> adt::Result { + return ValidIndexExprBuilder().Compose(tacked_indexes_expr_as_func, + indexes_expr); + }); + }; + return OpIndexesTransformApply(this->op_indexes_transform_signature, + ApplyTransform); +} + +adt::Result IndexClosure::operator()( + const 
IndexTupleExpr& indexes_expr) const { + return Match([&](const auto& impl) { return (*impl)(indexes_expr); }); +} + +} // namespace ap::index_expr diff --git a/paddle/ap/src/index_expr/index_expr_builtin_functions.cc b/paddle/ap/src/index_expr/index_expr_builtin_functions.cc new file mode 100644 index 00000000000000..eef5d8227608dc --- /dev/null +++ b/paddle/ap/src/index_expr/index_expr_builtin_functions.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ap/include/axpr/builtin_functions.h" +#include "paddle/ap/include/index_expr/index_expr_util.h" +#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" +#include "paddle/ap/include/index_expr/value.h" +#include "paddle/ap/include/index_expr/value_method_class.h" + +namespace ap::index_expr {} // namespace ap::index_expr diff --git a/paddle/ap/src/index_expr/index_expr_util.cc b/paddle/ap/src/index_expr/index_expr_util.cc new file mode 100644 index 00000000000000..a8a788ba9d6a00 --- /dev/null +++ b/paddle/ap/src/index_expr/index_expr_util.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/index_expr/index_expr_util.h" +#include "paddle/ap/include/axpr/adt.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" diff --git a/paddle/ap/src/index_expr/valid_index_expr_builder.cc b/paddle/ap/src/index_expr/valid_index_expr_builder.cc new file mode 100644 index 00000000000000..6d2f5188ca61cd --- /dev/null +++ b/paddle/ap/src/index_expr/valid_index_expr_builder.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/index_expr/index_expr.h" +#include "paddle/ap/include/index_expr/index_expr_util.h" +#include "paddle/ap/include/index_expr/index_tuple_expr.h" +#include "paddle/ap/include/index_expr/slice.h" + +namespace ap::index_expr {} // namespace ap::index_expr diff --git a/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc b/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc new file mode 100644 index 00000000000000..0f39d6cd3e015a --- /dev/null +++ b/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/kernel_dispatch/device_ctx_method_class.h" +#include "paddle/ap/include/axpr/value.h" + +namespace ap::kernel_dispatch { + +struct DeviceCtxMethodClass { + using Self = DeviceCtx; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.__adt_rc_shared_ptr_raw_ptr(); + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + static adt::Result GetStreamAddrAsVoidPtr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0) << adt::errors::TypeError{ + std::string() + + "DeviceCtx.get_stream_addr_as_void_ptr() takes 0 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(void_ptr, self.shared_ptr()->GetStreamAddrAsVoidPtr()); + return void_ptr; + } +}; + +axpr::TypeImpl> GetDeviceCtxClass() { + using Impl = DeviceCtxMethodClass; + static auto cls( + axpr::MakeBuiltinClass("DeviceCtx", [&](const auto& Yield) { + Yield("__str__", &Impl::ToString); + Yield("get_stream_addr_as_void_ptr", &Impl::GetStreamAddrAsVoidPtr); + })); + using Self = typename Impl::Self; + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::kernel_dispatch diff --git a/paddle/ap/src/paddle/pass/ap_drr_helper.cc b/paddle/ap/src/paddle/pass/ap_drr_helper.cc new file mode 100644 index 00000000000000..bf551dc81a8fc5 --- /dev/null +++ b/paddle/ap/src/paddle/pass/ap_drr_helper.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/ap_drr_helper.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/drr/builtin_frame_util.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_interpreter.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/value_method_class.h" +#include "paddle/ap/include/paddle/pir/pir_method_class.h" + +namespace cinn::dialect::ir { + +namespace adt = ap::adt; + +namespace { + +using Function = ap::axpr::Value; + +using DrrNode = ap::drr::Node; +using DrrCtx = ap::drr::DrrCtx; + +} // namespace + +ApDrrHelper::ApDrrHelper( + const std::weak_ptr& circlable_ref_list) + : drr_interpreter_(ap::paddle::GetPirClass(), circlable_ref_list) {} + +adt::Result ApDrrHelper::InterpretDrrCtxMaker( + const Function& lambda, const std::vector& args) { + return drr_interpreter_.InterpretDrrCtxMaker(lambda, args); +} + +adt::Result ApDrrHelper::Interpret(const Function& lambda, + const std::string& drr_pass_name) { + return drr_interpreter_.InterpretPass(lambda, drr_pass_name); +} + +adt::Result ApDrrHelper::CreateDrrCtxByDrrPassObj( + const ap::axpr::Value& obj) { + return drr_interpreter_.CreateDrrCtxByDrrPassObj(obj); +} + +adt::Result ApDrrHelper::Interpret( + const ap::axpr::ClassAttrs& cls) { + return drr_interpreter_.InterpretPass(cls); +} + +} // namespace 
cinn::dialect::ir diff --git a/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc b/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc new file mode 100644 index 00000000000000..fa2b370313d4b4 --- /dev/null +++ b/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/ap_kernel_define_helper.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/code_gen/builtin_frame_util.h" +#include "paddle/ap/include/code_gen/value.h" +#include "paddle/ap/include/code_gen/value_method_class.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/paddle/op_cuda_code_gen_impl.h" +#include "paddle/ap/include/paddle/pir_node_method_class.h" + +namespace cinn::dialect::ir { + +namespace adt = ap::adt; + +namespace { + +using Function = ap::axpr::Value; +using CodeModule = ap::code_module::CodeModule; +using PirNode = ap::paddle::PirNode; +using Val = ap::code_gen::Value; +using CodeGenCtx = ap::code_gen::CodeGenCtx; +using CodeGenResult = ap::code_gen::CodeGenResult; + +} // namespace + +adt::Result ApKernelDefineHelper::Interpret( + const Function& lambda, const CodeGenCtx& code_gen_ctx) { + ap::axpr::BuiltinClassInstance code_gen_ctx_instance{ + ap::code_gen::GetCodeGenCtxClass(), 
code_gen_ctx}; + ap::axpr::Interpreter interpreter( + ap::code_gen::MakeBuiltinFrameAttrMap(), circlable_ref_list_); + ADT_CHECK(code_gen_ctx->ir_match_ctx.has_value()); + const auto& ir_match_ctx = code_gen_ctx->ir_match_ctx.value(); + ap::ir_match::OpMatchCtx op_match_ctx{ir_match_ctx.shared_ptr()}; + ap::axpr::BuiltinClassInstance op_match_ctx_instance{ + ap::ir_match::GetOpMatchCtxClass(), op_match_ctx}; + ap::ir_match::TensorMatchCtx tensor_match_ctx{ + ir_match_ctx.shared_ptr()}; + ap::axpr::BuiltinClassInstance tensor_match_ctx_instance{ + ap::ir_match::GetTensorMatchCtxClass(), + tensor_match_ctx}; + ADT_LET_CONST_REF(result, + interpreter.Interpret(lambda, + {code_gen_ctx_instance, + op_match_ctx_instance, + tensor_match_ctx_instance})); + ADT_LET_CONST_REF(m, result.template CastTo()); + return m; +} + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc new file mode 100644 index 00000000000000..df1e5cd717465c --- /dev/null +++ b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc @@ -0,0 +1,3212 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" +#include "paddle/ap/include/memory/circlable_ref_list_base.h" + +#include "paddle/ap/include/adt/topo_walker.h" +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h" +#include "paddle/ap/include/axpr/cps_interpreter.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/code_gen/arg_source_maker.h" +#include "paddle/ap/include/code_gen/matched_result_pattern_helper.h" +#include "paddle/ap/include/code_gen/value.h" +#include "paddle/ap/include/code_module/module_to_axpr_helper.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/drr_pass_type_helper.h" +#include "paddle/ap/include/drr/res_ptn_packed_ir_op_declare_data.h" +#include "paddle/ap/include/drr/result_pattern_helper.h" +#include "paddle/ap/include/drr/value.h" +#include "paddle/ap/include/graph/graph_helper.h" +#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" +#include "paddle/ap/include/ir_match/graph_matcher.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" +#include "paddle/ap/include/ir_match/op_match_ctx_method_class.h" +#include "paddle/ap/include/ir_match/tensor_match_ctx_method_class.h" +#include "paddle/ap/include/paddle/indexed_ir_graph_util.h" +#include "paddle/ap/include/paddle/pass/ap_drr_helper.h" +#include "paddle/ap/include/paddle/pass/ap_kernel_define_helper.h" +#include "paddle/ap/include/paddle/pass/ap_registry_helper.h" +#include "paddle/ap/include/paddle/pass/ir_helper_method_class.h" +#include "paddle/ap/include/paddle/pir/manual_op.h" +#include 
"paddle/ap/include/paddle/pir/pir_method_class.h" +#include "paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h" +#include "paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h" +#include "paddle/ap/include/paddle/pir/program_method_class.h" +#include "paddle/ap/include/paddle/pir_graph_descriptor.h" +#include "paddle/ap/include/paddle/pir_node.h" +#include "paddle/ap/include/paddle/pir_node_descriptor.h" +#include "paddle/ap/include/reified_drr/reified_drr_pass_dump_helper.h" +#include "paddle/ap/src/paddle/pass/op_factory.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/src/ir_operation_factory.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace cinn::dialect::ir { + +namespace adt = ap::adt; + +namespace { + +using ap::paddle::PirNode; + +using DrrValue = ap::drr::Value; +using DrrNode = ap::drr::Node; + +using DrrCtx = ap::drr::DrrCtx; + +using DrrNativeIrValue = ap::drr::NativeIrValue; +using DrrPackedIrValue = ap::drr::PackedIrValue; +using DrrIrValue = ap::drr::IrValue; + +using DrrNativeIrOp = ap::drr::NativeIrOp; +using DrrNativeIrOpOperand = ap::drr::NativeIrOpOperand; +using DrrNativeIrOpResult = ap::drr::NativeIrOpResult; +using DrrPackedIrOp = ap::drr::PackedIrOp; +using DrrPackedIrOpOperand = ap::drr::PackedIrOpOperand; +using DrrPackedIrOpResult = ap::drr::PackedIrOpResult; +using DrrOptPackedIrOp = ap::drr::OptPackedIrOp; +using DrrOptPackedIrOpOperand = ap::drr::OptPackedIrOpOperand; +using DrrOptPackedIrOpResult = ap::drr::OptPackedIrOpResult; + +using DrrIrOpImpl = std::variant; + +using IrMatchCtx = ap::ir_match::IrMatchCtx; + +template +using NativeORGraph = + ap::graph::GraphDescriptor; + +template +using DefaultGraph = + 
ap::graph::GraphDescriptor; + +template +using RefAugmentedGraph = + ap::graph::GraphDescriptor; + +using ap::axpr::AnfExpr; +using CGValue = ap::code_gen::Value; +using CodeGenCtx = ap::code_gen::CodeGenCtx; +using CodeGenResult = ap::code_gen::CodeGenResult; +using ap::code_module::CodeModule; + +struct DrrIrOp : public DrrIrOpImpl { + using DrrIrOpImpl::DrrIrOpImpl; + ADT_DEFINE_VARIANT_METHODS(DrrIrOpImpl); + + const std::string& op_name() const { + return Match([](const auto& impl) -> const std::string& { + return impl->op_declare->op_name; + }); + } +}; +using DrrGraphNode = ap::graph::Node; +using GraphMatchCtx = ap::ir_match::GraphMatchCtx; + +using PirNativeIrValue = ap::paddle::NativeIrValue; +using PirNativeIrOpOperand = ap::paddle::NativeIrOpOperand; +using PirNativeIrOpResult = ap::paddle::NativeIrOpResult; + +adt::Result GetApDrrDefaultAnchor(const DrrCtx& drr_ctx) { + ADT_LET_CONST_REF(src_ptn_ctx, drr_ctx->GetSourcePatternCtx()); + auto ptn_node_area = src_ptn_ctx->node_arena; + ap::graph::GraphDescriptor + source_pattern_graph{}; + ADT_CHECK(ptn_node_area->nodes().size() > 0); + ap::graph::GraphHelper + graph_helper(source_pattern_graph); + const auto& start_ptn_node = ptn_node_area->nodes().at(0).node(); + ADT_LET_CONST_REF(anchor_node, graph_helper.FindAnchor(start_ptn_node)); + ADT_LET_CONST_REF(default_anchor, anchor_node.Get()); + return default_anchor; +} + +adt::Result> GetApDrrNativeIrOpAnchor( + const DrrCtx& drr_ctx) { + ADT_LET_CONST_REF(src_ptn_ctx, drr_ctx->GetSourcePatternCtx()); + auto ptn_node_area = src_ptn_ctx->node_arena; + ap::graph::GraphDescriptor + source_pattern_graph{}; + ADT_CHECK(ptn_node_area->nodes().size() > 0); + ap::graph::GraphHelper + graph_helper(source_pattern_graph); + const auto& start_ptn_node = ptn_node_area->nodes().at(0).node(); + auto IsNativeOpWithOutputs = [&](const auto& node) -> adt::Result { + ADT_LET_CONST_REF(drr_node, node.Get()); + ADT_LET_CONST_REF(downstreams, node.DownstreamNodes()); + return 
drr_node.template Has() && downstreams.size() > 0; + }; + const auto& Filter = IsNativeOpWithOutputs; + ADT_LET_CONST_REF(anchor_node, + graph_helper.FilterAnchor(start_ptn_node, Filter)); + if (!anchor_node.has_value()) { + return std::nullopt; + } + ADT_LET_CONST_REF(anchor, anchor_node.value().Get()); + ADT_LET_CONST_REF(native_ir_op, anchor.template TryGet()); + return native_ir_op; +} + +adt::Result> GetResPtnOutputs(const DrrCtx& drr_ctx) { + std::vector ret; + ADT_LET_CONST_REF(res_ptn_ctx, drr_ctx->GetResultPatternCtx()); + const auto& nodes = res_ptn_ctx->node_arena->nodes(); + for (const auto& drr_node : nodes) { + ADT_LET_CONST_REF(downstreams, drr_node.node().DownstreamNodes()); + if (downstreams.size() == 0) { + const auto& opt_drr_ir_value = DrrIrValue::OptCastFrom(drr_node); + if (opt_drr_ir_value.has_value()) { + ret.push_back(opt_drr_ir_value.value()); + } + } + } + return ret; +} + +class DrrCtxProvider { + public: + DrrCtxProvider() {} + + virtual adt::Result> GetDrrCtxList() = 0; + + virtual adt::Result PostProcess( + adt::Result> (*Match)(const DrrCtx&, + pir::Operation* op), + const DrrCtx& drr_ctx, + pir::Operation* op, + const GraphMatchCtx& match_ctx, + const std::function(const std::string&)>& + CodeGenResult4FusedOpName) = 0; +}; + +class NaiveDrrCtxProvider : public DrrCtxProvider { + DrrCtx drr_ctx_; + + public: + explicit NaiveDrrCtxProvider(const DrrCtx& drr_ctx) : drr_ctx_(drr_ctx) {} + + adt::Result> GetDrrCtxList() override { + return adt::List{drr_ctx_}; + } + + adt::Result PostProcess( + adt::Result> (*Match)(const DrrCtx&, + pir::Operation* op), + const DrrCtx& drr_ctx, + pir::Operation* op, + const GraphMatchCtx& match_ctx, + const std::function(const std::string&)>& + CodeGenResult4FusedOpName) override { + // Do Nothing. 
+ return adt::Ok{}; + } +}; + +struct ApLowerFusionOpPatternCtx { + std::shared_ptr drr_ctx_provider_; + DrrCtx drr_ctx; + std::vector res_ptn_outputs; + DrrNode default_anchor; + std::optional native_op_anchor; + std::string anchor_op_name; + std::optional steps_limit; + + const std::shared_ptr& drr_ctx_provider() const { + return drr_ctx_provider_; + } + + static adt::Result MakeFromDrrCtx( + const DrrCtx& drr_ctx, + std::optional steps_limit, + const std::shared_ptr& drr_ctx_provider) { + ADT_LET_CONST_REF(res_ptn_outputs, GetResPtnOutputs(drr_ctx)); + ADT_LET_CONST_REF(default_anchor, GetApDrrDefaultAnchor(drr_ctx)); + std::optional opt_native_op_anchor; + if (ap::drr::DrrPassTypeHelper{}.SupportOptionalPackedOp( + drr_ctx->drr_pass_type)) { + ADT_LET_CONST_REF(opt_native_ir_op_anchor, + GetApDrrNativeIrOpAnchor(drr_ctx)); + opt_native_op_anchor = opt_native_ir_op_anchor; + } else { + opt_native_op_anchor = std::nullopt; + } + ADT_LET_CONST_REF(anchor_op_name, + GetAnchorOpName(opt_native_op_anchor, default_anchor)); + return ApLowerFusionOpPatternCtx{drr_ctx_provider, + drr_ctx, + res_ptn_outputs, + default_anchor, + opt_native_op_anchor, + anchor_op_name, + steps_limit}; + } + + static adt::Result GetAnchorOpName( + const std::optional& native_op_anchor, + const DrrNode& default_anchor) { + if (native_op_anchor.has_value()) { + return native_op_anchor.value()->op_declare->op_name; + } + return default_anchor.Match( + [&](const DrrNativeIrOp& ir_op) -> adt::Result { + return ir_op->op_declare->op_name; + }, + [&](const DrrPackedIrOp& ir_op) -> adt::Result { + return PirNode::GetOpNameFromDrrPackedOpName( + ir_op->op_declare->op_name); + }, + [&](const DrrOptPackedIrOp& ir_op) -> adt::Result { + return PirNode::GetOpNameFromDrrPackedOpName( + ir_op->op_declare->op_name); + }, + [&](const auto&) -> adt::Result { + return adt::errors::TypeError{ + "default_anchor drr node should be a op node but value node " + "found."}; + }); + } +}; + +struct ApRewriter { + 
ApLowerFusionOpPatternCtx ctx_; + adt::Result> (*Match_)(const DrrCtx&, + pir::Operation* op); + mutable ApDrrHelper ap_drr_helper_; + + ApRewriter(const ApLowerFusionOpPatternCtx& ctx, + adt::Result> (*Match)( + const DrrCtx&, pir::Operation* op)) + : ctx_(ctx), + Match_(Match), + ap_drr_helper_(ctx_.drr_ctx->circlable_ref_list) {} + + adt::Result Rewrite(const GraphMatchCtx& match_ctx, + pir::Operation* op, + pir::PatternRewriter* rewriter) const { + ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); + LOG(ERROR) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; + return RewriteByResultPattern(match_ctx, op, rewriter); + } + + private: + adt::Result RewriteByResultPattern( + const GraphMatchCtx& match_ctx, + pir::Operation* op, + pir::PatternRewriter* rewriter) const { + ADT_LET_CONST_REF(rewritten, + TryRewriteByResultPattern(match_ctx, op, rewriter)); + return rewritten; + } + + using IrValue2UseIterators = + std::unordered_map>; + + struct RewriteCtx { + std::unordered_map matched_op2order_value; + IrValue2UseIterators output2original_uses; + std::unordered_map name2native_value; + std::unordered_map> name2packed_values; + + adt::Result GetMatchedOpOrderValue(pir::Operation* op) const { + const auto iter = this->matched_op2order_value.find(op); + if (iter == this->matched_op2order_value.end()) { + return adt::errors::IndexError{ + "RewriteCtx::GetMatchedOpOrderValue failed."}; + } + return iter->second; + } + + adt::Result GetNativeIrValue( + const std::string& ir_value_name) const { + const auto iter = this->name2native_value.find(ir_value_name); + if (iter == this->name2native_value.end()) { + return adt::errors::IndexError{ + "RewriteCtx::GetNativeIrValue() failed. 
key '" + ir_value_name + + "' not found."}; + } + return iter->second; + } + + adt::Result*> GetPackedIrValues( + const std::string& ir_value_name) const { + const auto iter = this->name2packed_values.find(ir_value_name); + if (iter == this->name2packed_values.end()) { + return adt::errors::IndexError{ + "RewriteCtx::GetPackedIrValues() failed. key '" + ir_value_name + + "' not found"}; + } + return &iter->second; + } + }; + + adt::Result> GetMatchedOps( + const GraphMatchCtx& match_ctx) const { + using DefaultDrrGraph = + ap::graph::GraphDescriptor; + DefaultDrrGraph default_drr_graph{}; + ADT_LET_CONST_REF(src_ptn_ctx, ctx_.drr_ctx->GetSourcePatternCtx()); + const auto& nodes = src_ptn_ctx->node_arena->nodes(); + std::unordered_set ops; + for (const auto& drr_node : nodes) { + ADT_LET_CONST_REF(is_op_node, + default_drr_graph.IsOpNode(drr_node.node())); + if (!is_op_node) { + continue; + } + ADT_LET_CONST_REF(pir_node, + match_ctx->GetSoleBigGraphNode(drr_node.node())); + const auto& opt_op = CastToPirOp(pir_node); + if (opt_op.has_value()) { + ADT_CHECK(ops.emplace(opt_op.value()).second); + } + } + return ops; + } + + std::optional CastToPirOp(const PirNode& pir_node) const { + return pir_node.Match( + [](const ap::paddle::NativeIrOp& ir_op) + -> std::optional { return ir_op.op; }, + [&](const ap::paddle::PackedIrOp& ir_op) + -> std::optional { + return static_cast(ir_op.fusion_op); + }, + [&](const auto&) -> std::optional { + return std::nullopt; + }); + } + + adt::Result InitRewriteCtx( + RewriteCtx* rewrite_ctx, const GraphMatchCtx& graph_match_ctx) const { + ADT_LET_CONST_REF(matched_op2order_value_map, + MakeMatchedOp2OrderValue(graph_match_ctx)); + rewrite_ctx->matched_op2order_value = matched_op2order_value_map; + auto* map = &rewrite_ctx->output2original_uses; + ADT_RETURN_IF_ERR(InitResPtnOutput2UseIterators(map, graph_match_ctx)); + return adt::Ok{}; + } + + adt::Result InitResPtnOutput2UseIterators( + IrValue2UseIterators* map, const GraphMatchCtx& 
graph_match_ctx) const { + auto UpdateValue2Use = [&](pir::Value output) -> adt::Result { + auto* lst = &(*map)[output]; + for (auto iter = output.use_begin(); iter != output.use_end(); ++iter) { + lst->emplace_back(iter); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitResPtnOutputPirValue(graph_match_ctx, UpdateValue2Use)); + return adt::Ok{}; + } + + adt::Result> + MakeMatchedOp2OrderValue(const GraphMatchCtx& match_ctx) const { + ADT_LET_CONST_REF(ops, GetMatchedOps(match_ctx)); + return MakeMatchedOp2OrderValue(ops); + } + + adt::Result> + MakeMatchedOp2OrderValue( + const std::unordered_set& ops) const { + std::unordered_map ret; + if (ops.empty()) { + return ret; + } + pir::Operation* start = *ops.begin(); + auto* block = start->GetParent(); + pir::Block::Iterator left_iter = *start; + pir::Block::Iterator right_iter = *start; + for (int i = 0; ret.size() < ops.size() && i < block->size(); ++i) { + if (ops.count(&*left_iter) > 0) { + ret[&*left_iter] = -i; + } + if (ops.count(&*right_iter) > 0) { + ret[&*right_iter] = i; + } + if (&*left_iter != &block->front()) { + --left_iter; + } + if (&*right_iter != &block->back()) { + ++right_iter; + } + } + ADT_CHECK(ret.size() == ops.size()); + return ret; + } + + using CodeGenResultCollectT = std::function( + const std::string& fused_op_name, const CodeGenResult&)>; + + adt::Result TryRewriteByResultPattern( + const GraphMatchCtx& match_ctx, + pir::Operation* op, + pir::PatternRewriter* rewriter) const { + ADT_RETURN_IF_ERR(WithPostProcessGuard( + match_ctx, + op, + [&](const auto& CodeGenResultCollect) -> adt::Result { + RewriteCtx rewrite_ctx; + ADT_RETURN_IF_ERR(InitRewriteCtx(&rewrite_ctx, match_ctx)); + auto Build = [&](const auto& res_ptn_op) -> adt::Result { + return BuildNewOp(rewriter, + res_ptn_op, + &rewrite_ctx, + match_ctx, + CodeGenResultCollect); + }; + ADT_RETURN_IF_ERR(VisitEachResPtnOp(Build)); + ADT_RETURN_IF_ERR( + ReplaceOutputResPtnTensor(match_ctx, rewrite_ctx, rewriter)); + return 
adt::Ok{}; + })); + return true; + } + + template + adt::Result WithPostProcessGuard( + const GraphMatchCtx& match_ctx, + pir::Operation* op, + const DoWithCollectorT& DoWithCollector) const { + std::map fused_op_name2code_gen_result; + auto CodeGenResultCollect = + [&](const std::string& fused_op_name, + const CodeGenResult& code_gen_result) -> adt::Result { + ADT_CHECK( + fused_op_name2code_gen_result.emplace(fused_op_name, code_gen_result) + .second); + return adt::Ok{}; + }; + + ADT_RETURN_IF_ERR(DoWithCollector(CodeGenResultCollect)); + + using RetT = adt::Result; + auto CodeGenResult4FusedOpName = + [&](const std::string& fused_op_name) -> RetT { + const auto& iter = fused_op_name2code_gen_result.find(fused_op_name); + ADT_CHECK(iter != fused_op_name2code_gen_result.end()); + return iter->second; + }; + ADT_RETURN_IF_ERR(ctx_.drr_ctx_provider()->PostProcess( + Match_, ctx_.drr_ctx, op, match_ctx, CodeGenResult4FusedOpName)); + return adt::Ok{}; + } + + adt::Result ReplaceOutputResPtnTensor( + const GraphMatchCtx& match_ctx, + const RewriteCtx& rewrite_ctx, + pir::PatternRewriter* rewriter) const { + auto Replace = [&](pir::Value from, pir::Value to) -> adt::Result { + // Reason for no use of `rewriter->ReplaceAllUsesWith(from, to)`: + // AP drr mechanism support result pattern like: + // o.foo_op( + // [o.bar_value], + // [o.bar_value] + // ) + // It will insert `foo_op` between pir::Value named `bar_value` and its + // consumer ops except the newly inserted `foo_op`. 
+ auto iter = rewrite_ctx.output2original_uses.find(from); + ADT_CHECK(iter != rewrite_ctx.output2original_uses.end()); + for (auto use_iter : iter->second) { + use_iter->set_source(to); + } + return adt::Ok{}; + }; + return VisitOutputPirValueReplacementPair(match_ctx, rewrite_ctx, Replace); + } + + template + adt::Result VisitResPtnOutputPirValue(const GraphMatchCtx& match_ctx, + const YieldT& Yield) const { + for (const auto& res_ptn_drr_ir_value : ctx_.res_ptn_outputs) { + const auto& opt_drr_ir_value = + SrcPtnIrValue4ResPtnIrValue(res_ptn_drr_ir_value); + ADT_CHECK(opt_drr_ir_value.has_value()); + const auto& drr_ir_value = opt_drr_ir_value.value(); + const auto& ret = drr_ir_value.Match( + [&](const DrrNativeIrValue& native_ir_value) -> adt::Result { + ADT_LET_CONST_REF( + pir_node, + match_ctx->GetSoleBigGraphNode(native_ir_value->node)); + ADT_LET_CONST_REF( + pir_value, + pir_node.template TryGet()); + return Yield(pir_value.value); + }, + [&](const DrrPackedIrValue& packed_ir_value) -> adt::Result { + ADT_LET_CONST_REF(from_nodes, + match_ctx->GetPackedBigGraphIrValueNodes( + packed_ir_value->node)); + for (int i = 0; i < from_nodes->size(); ++i) { + const auto& from_node = from_nodes->at(i); + ADT_LET_CONST_REF( + pir_value, + from_node.template TryGet()); + ADT_RETURN_IF_ERR(Yield(pir_value.value)); + } + return adt::Ok{}; + }); + ADT_RETURN_IF_ERR(ret); + } + return adt::Ok{}; + } + + template + adt::Result VisitOutputPirValueReplacementPair( + const GraphMatchCtx& match_ctx, + const RewriteCtx& rewrite_ctx, + const DoEachPairT& DoEachPair) const { + for (const auto& res_ptn_drr_ir_value : ctx_.res_ptn_outputs) { + const auto& opt_drr_ir_value = + SrcPtnIrValue4ResPtnIrValue(res_ptn_drr_ir_value); + ADT_CHECK(opt_drr_ir_value.has_value()); + const auto& drr_ir_value = opt_drr_ir_value.value(); + const auto& ret = drr_ir_value.Match( + [&](const DrrNativeIrValue& native_ir_value) -> adt::Result { + ADT_LET_CONST_REF( + pir_node, + 
match_ctx->GetSoleBigGraphNode(native_ir_value->node)); + ADT_LET_CONST_REF( + pir_value, + pir_node.template TryGet()); + pir::Value from = pir_value.value; + ADT_LET_CONST_REF( + to, rewrite_ctx.GetNativeIrValue(native_ir_value->name)); + return DoEachPair(from, to); + }, + [&](const DrrPackedIrValue& packed_ir_value) -> adt::Result { + ADT_LET_CONST_REF(from_nodes, + match_ctx->GetPackedBigGraphIrValueNodes( + packed_ir_value->node)); + ADT_LET_CONST_REF( + to_values_ptr, + rewrite_ctx.GetPackedIrValues(packed_ir_value->name)); + ADT_CHECK(from_nodes->size() == to_values_ptr->size()) + << adt::errors::ValueError{ + "from_nodes->size(): " + + std::to_string(from_nodes->size()) + + ", to_values_ptr->size(): " + + std::to_string(to_values_ptr->size()) + "."}; + for (int i = 0; i < from_nodes->size(); ++i) { + const auto& from_node = from_nodes->at(i); + ADT_LET_CONST_REF( + pir_value, + from_node.template TryGet()); + pir::Value from = pir_value.value; + pir::Value to = to_values_ptr->at(i); + ADT_RETURN_IF_ERR(DoEachPair(from, to)); + } + return adt::Ok{}; + }); + ADT_RETURN_IF_ERR(ret); + } + return adt::Ok{}; + } + + template + adt::Result VisitEachResPtnOp(const YieldT& Yield) const { + auto DoEachResPtnOp = + [&](const auto& res_ptn_graph_node) -> adt::Result { + ADT_LET_CONST_REF(res_ptn_node, res_ptn_graph_node.Get()); + const auto& opt_res_ptn_op = ConvertToResPtnOp(res_ptn_node); + if (opt_res_ptn_op.has_value()) { + ADT_RETURN_IF_ERR(Yield(opt_res_ptn_op.value())); + } + return adt::Ok{}; + }; + return VisitEachResPtnGraphNode(DoEachResPtnOp); + } + + template + adt::Result VisitEachResPtnGraphNode(const YieldT& Yield) const { + ADT_LET_CONST_REF(res_ptn_ctx, ctx_.drr_ctx->GetResultPatternCtx()); + std::list sources; + for (const auto& drr_node : res_ptn_ctx->node_arena->nodes()) { + const auto& drr_graph_node = drr_node.node(); + ADT_LET_CONST_REF(upstreams, drr_graph_node.UpstreamNodes()); + if (upstreams.size() == 0) { + 
sources.push_back(drr_graph_node); + } + } + using Ok = adt::Result; + ap::drr::DefaultDrrGraphDescriptor graph{}; + auto VisitPrev = [&](const DrrGraphNode& node, const auto& Yield) -> Ok { + return graph.VisitUpstreamNodes(node, Yield); + }; + auto VisitNext = [&](const DrrGraphNode& node, const auto& Yield) -> Ok { + return graph.VisitDownstreamNodes(node, Yield); + }; + ap::adt::TopoWalker walker{VisitPrev, VisitNext}; + ADT_RETURN_IF_ERR(walker(sources.begin(), sources.end(), Yield)); + return adt::Ok{}; + } + + std::optional ConvertToResPtnOp(const DrrNode& drr_node) const { + return drr_node.Match( + [&](const DrrNativeIrOp& ir_op) -> std::optional { + return DrrIrOp{ir_op}; + }, + [&](const DrrPackedIrOp& ir_op) -> std::optional { + return DrrIrOp{ir_op}; + }, + [&](const auto&) -> std::optional { return std::nullopt; }); + } + + adt::Result BuildNewOp( + pir::PatternRewriter* rewriter, + const DrrIrOp& res_ptn_op, + RewriteCtx* rewrite_ctx, + const GraphMatchCtx& match_ctx, + const CodeGenResultCollectT& CodeGenResultCollect) const { + return res_ptn_op.Match( + [&](const DrrNativeIrOp& ir_op) -> adt::Result { + return BuildNativeOp( + rewriter, ir_op, rewrite_ctx, match_ctx, CodeGenResultCollect); + }, + [&](const DrrPackedIrOp& ir_op) -> adt::Result { + return BuildPackedOp( + rewriter, ir_op, rewrite_ctx, match_ctx, CodeGenResultCollect); + }); + } + + adt::Result BuildNativeOp( + pir::PatternRewriter* rewriter, + const DrrNativeIrOp& res_ptn_ir_op, + RewriteCtx* rewrite_ctx, + const GraphMatchCtx& match_ctx, + const CodeGenResultCollectT& CodeGenResultCollect) const { + ADT_RETURN_IF_ERR( + InsertInputPirValueToReplaceCtx(res_ptn_ir_op, rewrite_ctx, match_ctx)); + ADT_LET_CONST_REF(input_values, + GetNativeOpInputValues(res_ptn_ir_op, *rewrite_ctx)); + ADT_RETURN_IF_ERR( + TrySetInsertPointer(rewriter, *rewrite_ctx, input_values, match_ctx)); + ADT_LET_CONST_REF(attributes, + GetResPtnOpAttributes(res_ptn_ir_op, match_ctx)); + ADT_LET_CONST_REF( + 
output_values, + ConstructNativeOp(rewriter, res_ptn_ir_op, input_values, attributes)); + ADT_RETURN_IF_ERR(UpdateConstructedOpOutputsInReplaceCtx( + match_ctx, output_values, res_ptn_ir_op, rewrite_ctx)); + return adt::Ok{}; + } + + adt::Result GetResPtnOpAttributes( + const DrrNativeIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + ADT_CHECK(ctx_.drr_ctx->source_pattern_ctx.has_value()); + IrMatchCtx ir_match_ctx{ctx_.drr_ctx->source_pattern_ctx.value(), + match_ctx}; + const auto& args = GetResPtnAttrGetterArgs(ir_match_ctx); + pir::AttributeMap attrs; + auto* drr_interpreter = ap_drr_helper_.mut_drr_interpreter(); + using Ok = adt::Result; + auto CollectAttr = [&](const auto& attr_name, const auto& getter) -> Ok { + ADT_LET_CONST_REF(attr_val, drr_interpreter->Interpret(getter, args)); + ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); + ADT_LET_CONST_REF(attr, attr_val.template CastTo()); + ADT_CHECK(attrs.emplace(attr_name, attr).second); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitResPtnOpAttr(res_ptn_ir_op, CollectAttr)); + return attrs; + } + + std::vector GetResPtnAttrGetterArgs( + const IrMatchCtx& ir_match_ctx) const { + ap::ir_match::OpMatchCtx op_match_ctx{ir_match_ctx.shared_ptr()}; + ap::ir_match::TensorMatchCtx tensor_match_ctx{ + ir_match_ctx.shared_ptr()}; + return std::vector{ + ap::ir_match::GetOpMatchCtxClass().New( + op_match_ctx), + ap::ir_match::GetTensorMatchCtxClass().New( + tensor_match_ctx), + }; + } + + template + adt::Result VisitResPtnOpAttr(const DrrNativeIrOp& res_ptn_ir_op, + const YieldT& Yield) const { + const auto& attr_map = res_ptn_ir_op->op_declare->attr_map; + for (const auto& [attr_name, getter] : attr_map->storage) { + ADT_RETURN_IF_ERR(Yield(attr_name, getter)); + } + return adt::Ok{}; + } + + adt::Result> ConstructNativeOp( + pir::PatternRewriter* rewriter, + const DrrNativeIrOp& res_ptn_ir_op, + const std::vector& inputs, + const pir::AttributeMap& attrs) const { + { + ADT_LET_CONST_REF( + opt_op, + 
ap::paddle::CreateOperation( + rewriter, res_ptn_ir_op->op_declare->op_name, inputs, attrs)); + if (opt_op.has_value()) { + return opt_op.value()->results(); + } + } + try { + pir::Operation* op = + paddle::drr::OperationFactory::Instance().CreateOperation( + res_ptn_ir_op->op_declare->op_name, inputs, attrs, *rewriter); + return op->results(); + } catch (const std::exception& e) { + return adt::errors::ValueError{ + std::string() + + "OperationFactory::Instance().CreateOperation() failed. op_name: " + + res_ptn_ir_op->op_declare->op_name + ". what(): " + e.what()}; + } + } + + adt::Result BuildPackedOp( + pir::PatternRewriter* rewriter, + const DrrPackedIrOp& res_ptn_ir_op, + RewriteCtx* rewrite_ctx, + const GraphMatchCtx& match_ctx, + const CodeGenResultCollectT& CodeGenResultCollect) const { + ADT_CHECK(res_ptn_ir_op->op_declare->op_name, "ap_pattern_fusion_op"); + ADT_RETURN_IF_ERR( + InsertInputPirValueToReplaceCtx(res_ptn_ir_op, rewrite_ctx, match_ctx)); + ADT_LET_CONST_REF(input_values, + GetPackedOpInputValues(res_ptn_ir_op, *rewrite_ctx)); + ADT_RETURN_IF_ERR( + TrySetInsertPointer(rewriter, *rewrite_ctx, input_values, match_ctx)); + ADT_LET_CONST_REF(combined_value, InsertCombinedOp(rewriter, input_values)); + ADT_LET_CONST_REF(code_gen_result, CodeGen(res_ptn_ir_op, match_ctx)); + ADT_RETURN_IF_ERR( + CodeGenResultCollect(res_ptn_ir_op->name, code_gen_result)); + ADT_LET_CONST_REF( + code_module_anf_expr, + ConvertApKernelModuleToAnfExpr(code_gen_result->code_module)); + const auto& code_gen_lambda_str = code_module_anf_expr.DumpToJsonString(); + const auto& kernel_dispatch_func = code_gen_result->kernel_dispatch_func; + const auto& kernel_dispatch_const_data = + code_gen_result->kernel_dispatch_const_data; + ADT_LET_CONST_REF(infer_meta_lambda_str, + GetInferMetaLambdaStr(res_ptn_ir_op, match_ctx)); + ADT_LET_CONST_REF(kernel_dispatch_lambda_str, + GetKernelDispatchLambdaStr(kernel_dispatch_func)); + ADT_LET_CONST_REF( + 
kernel_dispatch_const_data_lambda_str, + GetKernelDispatchConstDataLambdaStr( + res_ptn_ir_op, match_ctx, kernel_dispatch_const_data)); + ADT_LET_CONST_REF(num_outputs, + GetApKernelNumOutputs(res_ptn_ir_op, match_ctx)); + ADT_LET_CONST_REF( + ap_pattern_fusion_combined_out, + MakeApPatternFusionOp(rewriter, + combined_value, + num_outputs, + code_gen_lambda_str, + infer_meta_lambda_str, + kernel_dispatch_lambda_str, + kernel_dispatch_const_data_lambda_str)); + ADT_LET_CONST_REF( + output_values, + GetPackedOpOutputValues(rewriter, ap_pattern_fusion_combined_out)); + ADT_RETURN_IF_ERR(UpdateConstructedOpOutputsInReplaceCtx( + match_ctx, output_values, res_ptn_ir_op, rewrite_ctx)); + return adt::Ok{}; + } + + struct InputDimIndex { + int input_idx; + int tensor_axis; + }; + + struct OpInferMetaCtx { + std::unordered_map dim_expr2in_dim_index; + mutable std::unordered_map dim_expr2anf_expr; + }; + + struct TensorMeta { + std::vector shape; + pir::Type dtype; + }; + + adt::Result GetInferMetaLambdaStr( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + ADT_LET_CONST_REF(infer_meta_ctx, + GetOpInferMetaCtx(res_ptn_ir_op, match_ctx)); + ADT_LET_CONST_REF(outputs, GetOpOutputPirValues(res_ptn_ir_op, match_ctx)); + auto ConstructLambdaBody = + [&](ap::axpr::LetContext& ctx) -> adt::Result { + for (int i = 0; i < outputs.size(); ++i) { + const auto& output = outputs.at(i); + auto& output_meta_var = ctx.Var("outputs").At(i); + ADT_LET_CONST_REF(dim_exprs_ptr, GetShapeDimExprsPtrByValue(output)); + ADT_LET_CONST_REF(ddim_val, + ConstructDDims(&ctx, infer_meta_ctx, *dim_exprs_ptr)); + output_meta_var.SetAttr("dims", ddim_val); + ADT_LET_CONST_REF(dtype, GetPirDataType(output)); + ADT_LET_CONST_REF(dtype_val, + ConstructDtype(&ctx, infer_meta_ctx, dtype)); + output_meta_var.SetAttr("dtype", dtype_val); + } + return ctx.None(); + }; + ap::axpr::LambdaExprBuilder lmbd; + ADT_LET_CONST_REF( + anf_expr, lmbd.TryLambda({"inputs", "outputs"}, 
ConstructLambdaBody)); + return anf_expr.DumpToJsonString(); + } + + adt::Result GetPirDataType(pir::Value value) const { + if (!value.type().isa()) { + return adt::errors::NotImplementedError{ + "pir value must be of DenseTensorType"}; + } + const auto dense_tensor_type = + value.type().dyn_cast(); + return dense_tensor_type.dtype(); + } + + adt::Result> GetOpOutputPirValues( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + return GetMatchedPirOutputsOfRestPtnPackedIrOp(res_ptn_ir_op, match_ctx); + } + + adt::Result GetOpInferMetaCtx( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + ADT_LET_CONST_REF( + inputs, + GetMatchedPirInputsOfRestPtnPackedIrOp(res_ptn_ir_op, match_ctx)); + OpInferMetaCtx infer_meta_ctx{}; + auto* map = &infer_meta_ctx.dim_expr2in_dim_index; + for (int in_idx = 0; in_idx < inputs.size(); ++in_idx) { + pir::Value input = inputs.at(in_idx); + ADT_LET_CONST_REF(dim_exprs_ptr, GetShapeDimExprsPtrByValue(input)); + for (int tensor_axis = 0; tensor_axis < dim_exprs_ptr->size(); + ++tensor_axis) { + const auto& dim_expr = dim_exprs_ptr->at(tensor_axis); + map->emplace(dim_expr, InputDimIndex{in_idx, tensor_axis}); + } + } + return infer_meta_ctx; + } + + adt::Result*> GetShapeDimExprsPtrByValue( + pir::Value value) const { + auto* op = value.defining_op(); + ADT_CHECK(op != nullptr); + auto* program = op->GetParentProgram(); + auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(program); + const auto& shape_or_data = shape_analysis.GetShapeOrDataForValue(value); + using RetT = adt::Result*>; + return shape_or_data.Match( + [&](const symbol::TensorShapeOrDataDimExprs& impl) -> RetT { + return &impl.shape(); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + "GetShapeDimExprsPtrByValue only support " + "TensorShapeOrDataDimExprs."}; + }); + } + + adt::Result ConstructDtype(ap::axpr::LetContext* ctx, + const OpInferMetaCtx& infer_meta_ctx, + pir::Type 
type) const { + try { + ::phi::DataType phi_dtype = ::paddle::dialect::TransToPhiDataType(type); + ADT_LET_CONST_REF(dtype, ap::axpr::GetDataTypeFromPhiDataType(phi_dtype)); + return static_cast(ctx->Var("DataType").Attr(dtype.Name())); + } catch (const std::exception& e) { + return adt::errors::TypeError{ + "failed to cast from pir data type to phi data type."}; + } + } + + adt::Result ConstructDDims( + ap::axpr::LetContext* ctx, + const OpInferMetaCtx& infer_meta_ctx, + const std::vector& dim_exprs) const { + std::vector anf_dims; + for (const auto& dim_expr : dim_exprs) { + ADT_LET_CONST_REF(anf_dim_expr, + ConstructDDimDimExpr(ctx, infer_meta_ctx, dim_expr)); + anf_dims.emplace_back(anf_dim_expr); + } + return ctx->Call(ap::axpr::kBuiltinList(), anf_dims); + } + + adt::Result ConstructDDimDimExpr( + ap::axpr::LetContext* ctx, + const OpInferMetaCtx& infer_meta_ctx, + const symbol::DimExpr& dim_expr) const { + return dim_expr.Match( + [&](int64_t c) -> adt::Result { return ctx->Int64(c); }, + [&](const auto&) -> adt::Result { + return ConstructDDimDimExprByInputs(ctx, infer_meta_ctx, dim_expr); + }); + } + + adt::Result ConstructDDimDimExprByInputs( + ap::axpr::LetContext* ctx, + const OpInferMetaCtx& infer_meta_ctx, + const symbol::DimExpr& dim_expr) const { + const auto& idx_iter = infer_meta_ctx.dim_expr2in_dim_index.find(dim_expr); + ADT_CHECK(idx_iter != infer_meta_ctx.dim_expr2in_dim_index.end()); + auto anf_expr_iter = infer_meta_ctx.dim_expr2anf_expr.find(dim_expr); + if (anf_expr_iter == infer_meta_ctx.dim_expr2anf_expr.end()) { + const auto& in_dim = ConstructInDimExpr(ctx, idx_iter->second); + anf_expr_iter = + infer_meta_ctx.dim_expr2anf_expr.emplace(dim_expr, in_dim).first; + } + return anf_expr_iter->second; + } + + AnfExpr ConstructInDimExpr(ap::axpr::LetContext* ctx, + const InputDimIndex& idx) const { + return static_cast( + ctx->Var("inputs").At(idx.input_idx).Attr("dims").At(idx.tensor_axis)); + } + + adt::Result 
GetCodeFromBuiltinSerializableAttrMap( + ap::axpr::LetContext* ctx, + const ap::axpr::AttrMap& attr_map) const { + return ap::axpr::BuiltinSerializableAttrMapToAxprHelper{}.Convert(ctx, + attr_map); + } + + adt::Result GetKernelDispatchConstDataLambdaStr( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx, + const ap::axpr::AttrMap& + kernel_dispatch_const_data) const { + ap::axpr::LambdaExprBuilder lmbd; + auto ConstructLambdaBody = [&](auto& ctx) -> adt::Result { + ADT_LET_CONST_REF(data, + GetCodeFromBuiltinSerializableAttrMap( + &ctx, kernel_dispatch_const_data)); + return data; + }; + ADT_LET_CONST_REF(anf_expr, lmbd.TryLambda({}, ConstructLambdaBody)); + return anf_expr.DumpToJsonString(); + } + + struct SerializedCodeGenResult { + std::string code_gen_lambda_str; + ap::axpr::Function kernel_dispatch_func; + ap::axpr::AttrMap kernel_dispatch_const_data; + }; + + adt::Result CodeGen(const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + const auto& op_declare = res_ptn_ir_op->op_declare; + ADT_LET_CONST_REF( + op_declare_data, + op_declare->cast_data()); + const auto& lambda = op_declare_data->code_gen_func(); + ADT_LET_CONST_REF(code_gen_result, + GetApKernelModule(lambda, match_ctx, res_ptn_ir_op)); + const auto& kernel_dispatch_func = code_gen_result->kernel_dispatch_func; + auto* data = &code_gen_result.shared_ptr()->kernel_dispatch_const_data; + ADT_RETURN_IF_ERR(InsertOrCheckApKernelInputIndexOrSlices( + data, res_ptn_ir_op, match_ctx)); + ADT_RETURN_IF_ERR(InsertOrCheckApKernelOutputIndexOrSlices( + data, res_ptn_ir_op, match_ctx)); + ADT_RETURN_IF_ERR( + InsertOrCheckApKernelInputName2Index(data, res_ptn_ir_op, match_ctx)); + ADT_RETURN_IF_ERR( + InsertOrCheckApKernelOutputName2Index(data, res_ptn_ir_op, match_ctx)); + return code_gen_result; + } + + adt::Result InsertOrCheckApKernelInputIndexOrSlices( + ap::axpr::AttrMap* object, + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const 
{ + adt::List list; + using Ok = adt::Result; + auto DoEachIndex = [&](int64_t idx) -> Ok { + list->emplace_back(idx); + return adt::Ok{}; + }; + auto DoEachSlice = [&](int64_t start, int64_t end) -> Ok { + adt::List range{start, end}; + list->emplace_back(range); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitApKernelInputIndexOrSlice( + res_ptn_ir_op, match_ctx, DoEachIndex, DoEachSlice)); + const std::string key{"__builtin_ap_kernel_input_indexes_slices"}; + if ((*object)->Has(key)) { + ADT_LET_CONST_REF(old_list, (*object)->Get(key)); + ADT_CHECK(ap::axpr::SerializableValue{list} == old_list); // NOLINT + } else { + ADT_CHECK((*object)->Emplace(key, list)); + } + return adt::Ok{}; + } + + adt::Result InsertOrCheckApKernelOutputIndexOrSlices( + ap::axpr::AttrMap* object, + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + adt::List list; + using Ok = adt::Result; + auto DoEachIndex = [&](int64_t idx) -> Ok { + list->emplace_back(idx); + return adt::Ok{}; + }; + auto DoEachSlice = [&](int64_t start, int64_t end) -> Ok { + adt::List range{start, end}; + list->emplace_back(range); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitApKernelOutputIndexOrSlice( + res_ptn_ir_op, match_ctx, DoEachIndex, DoEachSlice)); + const std::string key{"__builtin_ap_kernel_output_indexes_slices"}; + if ((*object)->Has(key)) { + ADT_LET_CONST_REF(old_list, (*object)->Get(key)); + ADT_CHECK(ap::axpr::SerializableValue{list} == old_list); // NOLINT + } else { + ADT_CHECK((*object)->Emplace(key, list)); + } + return adt::Ok{}; + } + + adt::Result GetApKernelModule( + const ap::axpr::Value& lambda, + const GraphMatchCtx& match_ctx, + const DrrPackedIrOp& res_ptn_ir_op) const { + ADT_LET_CONST_REF(src_ptn_ctx, ctx_.drr_ctx->GetSourcePatternCtx()); + IrMatchCtx ir_match_ctx{src_ptn_ctx, match_ctx}; + ADT_LET_CONST_REF(arg_source_ctx, + MakeArgSourceCtx(match_ctx, res_ptn_ir_op)); + CodeGenCtx code_gen_ctx{ir_match_ctx, res_ptn_ir_op, arg_source_ctx}; + 
ApKernelDefineHelper helper{ctx_.drr_ctx->circlable_ref_list}; + ADT_LET_CONST_REF(result, helper.Interpret(lambda, code_gen_ctx)); + return result; + } + + adt::Result> MakeArgSourceCtx( + const GraphMatchCtx& match_ctx, + const DrrPackedIrOp& res_ptn_ir_op) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + ap::code_gen::ArgSourceMaker maker{helper}; + ADT_LET_CONST_REF(arg_source_ctx, maker.MakeArgSourceCtx(res_ptn_ir_op)); + return arg_source_ctx; + } + + adt::Result ConvertApKernelModuleToAnfExpr( + const CodeModule& m) const { + return ap::code_module::ModuleToAxprHelper{}.ConvertModuleToAnfExpr(m); + } + + adt::Result GetKernelDispatchLambdaStr( + const ap::axpr::Function& + kernel_dispatch_func) const { + const auto& lambda = kernel_dispatch_func->lambda; + ap::axpr::AnfExpr anf_expr = ap::axpr::ConvertCoreExprToAnfExpr(lambda); + return anf_expr.DumpToJsonString(); + } + + adt::Result MakeApPatternFusionOp( + pir::PatternRewriter* rewriter, + pir::Value input, + std::size_t num_outputs, + const std::string& code_gen_lambda_str, + const std::string& infer_meta_lambda_str, + const std::string& kernel_dispatch_lambda_str, + const std::string& kernel_dispatch_const_data_lambda_str) const { + auto ap_unary = rewriter->Build( + input, + num_outputs, + code_gen_lambda_str, + infer_meta_lambda_str, + kernel_dispatch_lambda_str, + kernel_dispatch_const_data_lambda_str); + return ap_unary.out(); + } + + adt::Result> GetPackedOpOutputValues( + pir::PatternRewriter* rewriter, pir::Value combined_out) const { + auto split_op = rewriter->Build(combined_out); + return split_op.outputs(); + } + + template + adt::Result UpdateConstructedOpOutputsInReplaceCtx( + const GraphMatchCtx& match_ctx, + const std::vector& output_values, + const IrOpT& res_ptn_ir_op, + RewriteCtx* rewrite_ctx) const { + auto UpdateRewriteCtx = [&](const DrrIrValue& ir_value, + const std::vector& output_slice) + -> adt::Result { + return ir_value.Match( + 
[&](const DrrNativeIrValue& ir_value) -> adt::Result { + ADT_CHECK(output_slice.size() == 1); + const auto& k = ir_value->name; + const auto& v = output_slice.at(0); + ADT_CHECK(rewrite_ctx->name2native_value.emplace(k, v).second); + return adt::Ok{}; + }, + [&](const DrrPackedIrValue& ir_value) -> adt::Result { + const auto& k = ir_value->name; + const auto& v = output_slice; + ADT_CHECK(rewrite_ctx->name2packed_values.emplace(k, v).second); + return adt::Ok{}; + }); + }; + ADT_RETURN_IF_ERR(VisitEachMatchedDrrIrValueAndOutputSlice( + match_ctx, output_values, res_ptn_ir_op, UpdateRewriteCtx)); + return adt::Ok{}; + } + + template + adt::Result VisitEachMatchedDrrIrValueAndOutputSlice( + const GraphMatchCtx& match_ctx, + const std::vector& output_values, + const IrOpT& res_ptn_ir_op, + const YieldT& Yield) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.VisitEachMatchedDrrIrValueAndOutputSlice( + output_values, res_ptn_ir_op, Yield); + } + + adt::Result GetResPtnNumPirValues( + const DrrIrValue& drr_ir_value, const GraphMatchCtx& match_ctx) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.GetResPtnNumBirValues(drr_ir_value); + } + + adt::Result GetApKernelNumOutputs( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.GetApKernelNumOutputs(res_ptn_ir_op); + } + + template + adt::Result VisitApKernelInputIndexOrSlice( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx, + const DoEachIndexT& DoEachIndex, + const DoEachSliceT& DoEachSlice) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.VisitApKernelInputIndexOrSlice( + res_ptn_ir_op, DoEachIndex, DoEachSlice); + } + + template + adt::Result VisitApKernelOutputIndexOrSlice( + const DrrPackedIrOp& res_ptn_ir_op, + const 
GraphMatchCtx& match_ctx, + const DoEachIndexT& DoEachIndex, + const DoEachSliceT& DoEachSlice) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.VisitApKernelOutputIndexOrSlice( + res_ptn_ir_op, DoEachIndex, DoEachSlice); + } + + adt::Result InsertOrCheckApKernelInputName2Index( + ap::axpr::AttrMap* object, + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + ap::axpr::AttrMap name2idx; + int64_t idx = 0; + auto DoEachIrValue = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + ADT_CHECK(name2idx->Emplace(drr_ir_value.name(), idx)); + ++idx; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, DoEachIrValue)); + const std::string key{"__builtin_ap_kernel_input_name_to_index"}; + if ((*object)->Has(key)) { + ADT_LET_CONST_REF(old_name2idx_val, (*object)->Get(key)); + ADT_LET_CONST_REF(old_name2idx, + old_name2idx_val.template TryGet< + ap::axpr::AttrMap>()); + ADT_CHECK(old_name2idx->storage == name2idx->storage); + } else { + ADT_CHECK((*object)->Emplace(key, name2idx)); + } + return adt::Ok{}; + } + + adt::Result InsertOrCheckApKernelOutputName2Index( + ap::axpr::AttrMap* object, + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + ap::axpr::AttrMap name2idx; + int64_t idx = 0; + auto DoEachIrValue = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + ADT_CHECK(name2idx->Emplace(drr_ir_value.name(), idx)); + ++idx; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, DoEachIrValue)); + const std::string key{"__builtin_ap_kernel_output_name_to_index"}; + if ((*object)->Has(key)) { + ADT_LET_CONST_REF(old_name2idx_val, (*object)->Get(key)); + ADT_LET_CONST_REF(old_name2idx, + old_name2idx_val.template TryGet< + ap::axpr::AttrMap>()); + ADT_CHECK(old_name2idx->storage == name2idx->storage); + } else { + ADT_CHECK((*object)->Emplace(key, name2idx)); + } + 
return adt::Ok{}; + } + + adt::Result InsertCombinedOp( + pir::PatternRewriter* rewriter, + const std::vector& inputs) const { + auto combined_op = rewriter->Build(inputs); + return combined_op.out(); + } + + adt::Result TrySetInsertPointer( + pir::PatternRewriter* rewriter, + const RewriteCtx& rewrite_ctx, + const std::vector& input_values, + const GraphMatchCtx& match_ctx) const { + ADT_LET_CONST_REF(opt_last_pir_op, + GetLastInputPirOp(rewriter->block(), input_values)); + if (opt_last_pir_op.has_value()) { + rewriter->SetInsertionPointAfter(opt_last_pir_op.value()); + } else { + ADT_RETURN_IF_ERR( + SetDefaultInsertPointer(rewriter, rewrite_ctx, match_ctx)); + } + return adt::Ok{}; + } + + adt::Result> GetLastInputPirOp( + pir::Block* block, const std::vector& input_values) const { + const auto& ops = [&] { + std::unordered_set ret; + for (const auto& value : input_values) { + if (!value) { + continue; + } + if (value.defining_op() != nullptr && + value.defining_op()->GetParent() == block) { + ret.insert(value.defining_op()); + } + } + return ret; + }(); + ADT_LET_CONST_REF(input_op2order_value, MakeMatchedOp2OrderValue(ops)); + auto OptOrderValue4Op = [&](pir::Operation* op) -> std::optional { + const auto iter = input_op2order_value.find(op); + if (iter == input_op2order_value.end()) { + return std::nullopt; + } + return iter->second; + }; + std::optional last_op; + std::optional op_order_value; + for (auto* op : ops) { + const auto& order_value = OptOrderValue4Op(op); + if (!order_value.has_value()) { + continue; + } + if (!op_order_value.has_value() || + op_order_value.value() < order_value.value()) { + op_order_value = order_value.value(); + last_op = op; + } + } + return last_op; + } + + adt::Result SetDefaultInsertPointer( + pir::PatternRewriter* rewriter, + const RewriteCtx& rewrite_ctx, + const GraphMatchCtx& match_ctx) const { + ADT_LET_CONST_REF(last_pir_op, GetLastMatchedPirOp(rewrite_ctx, match_ctx)); + rewriter->SetInsertionPointAfter(last_pir_op); 
+ return adt::Ok{}; + } + + adt::Result GetLastMatchedPirOp( + const RewriteCtx& rewrite_ctx, const GraphMatchCtx& match_ctx) const { + std::optional last_op; + std::optional op_order_value; + auto UpdatePirOp = [&](pir::Operation* op) -> adt::Result { + ADT_LET_CONST_REF(order_value, rewrite_ctx.GetMatchedOpOrderValue(op)); + if (!op_order_value.has_value() || op_order_value.value() < order_value) { + op_order_value = order_value; + last_op = op; + } + return adt::Ok{}; + }; + auto UpdateLastOp = [&](const DrrGraphNode& op) -> adt::Result { + ADT_LET_CONST_REF(pir_node, match_ctx->GetSoleBigGraphNode(op)); + return pir_node.Match( + [&](const ap::paddle::NativeIrOp& ir_op) -> adt::Result { + return UpdatePirOp(ir_op.op); + }, + [&](const ap::paddle::PackedIrOp& ir_op) -> adt::Result { + return UpdatePirOp(ir_op.fusion_op); + }, + [](const auto&) -> adt::Result { return adt::Ok{}; }); + }; + ADT_RETURN_IF_ERR(match_ctx->VisitSmallGraphNode(UpdateLastOp)); + ADT_CHECK(last_op.has_value()); + return last_op.value(); + } + + template + adt::Result VisitResPtnInputIrValueByResPtnIrOp( + const IrOpT& res_ptn_ir_op, const YieldT& Yield) const { + ap::drr::ResultPatternHelper helper{ctx_.drr_ctx}; + return helper.VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, Yield); + } + + template + adt::Result VisitResPtnOutputIrValueByResPtnIrOp( + const IrOpT& res_ptn_ir_op, const YieldT& Yield) const { + ap::drr::ResultPatternHelper helper{ctx_.drr_ctx}; + return helper.VisitResPtnOutputIrValueByResPtnIrOp(res_ptn_ir_op, Yield); + } + + std::optional SrcPtnIrValue4ResPtnIrValue( + const DrrIrValue& res_ptn_ir_value) const { + ap::drr::ResultPatternHelper helper{ctx_.drr_ctx}; + return helper.SrcPtnIrValue4ResPtnIrValue(res_ptn_ir_value); + } + + template + adt::Result InsertInputPirValueToReplaceCtx( + const IrOpT& res_ptn_ir_op, + RewriteCtx* rewrite_ctx, + const GraphMatchCtx& match_ctx) const { + using Ok = adt::Result; + auto InitInput = [&](const DrrIrValue& drr_ir_value) 
-> Ok { + return drr_ir_value.Match( + [&](const DrrNativeIrValue& res_ptn_ir_value) -> Ok { + ADT_RETURN_IF_ERR(InsertNativeIrValueToReplaceCtx( + res_ptn_ir_value, rewrite_ctx, match_ctx)); + return adt::Ok{}; + }, + [&](const DrrPackedIrValue& res_ptn_ir_value) -> Ok { + ADT_RETURN_IF_ERR(InsertPackedIrValueToReplaceCtx( + res_ptn_ir_value, rewrite_ctx, match_ctx)); + return adt::Ok{}; + }); + }; + ADT_RETURN_IF_ERR( + VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, InitInput)); + return adt::Ok{}; + } + + adt::Result InsertNativeIrValueToReplaceCtx( + const DrrNativeIrValue& res_ptn_ir_value, + RewriteCtx* rewrite_ctx, + const GraphMatchCtx& match_ctx) const { + const auto iter = + rewrite_ctx->name2native_value.find(res_ptn_ir_value->name); + if (iter != rewrite_ctx->name2native_value.end()) { + return adt::Ok{}; + } + const auto& opt_ir_value = SrcPtnIrValue4ResPtnIrValue(res_ptn_ir_value); + ADT_CHECK(opt_ir_value.has_value()); + const auto& ir_value = opt_ir_value.value(); + ADT_LET_CONST_REF(pir_node, + match_ctx->GetSoleBigGraphNode(ir_value.node())); + ADT_LET_CONST_REF(pir_value, + pir_node.template TryGet()) + << adt::errors::TypeError{ + "pir_node is not an ap::paddle::NativeIrValue"}; + rewrite_ctx->name2native_value[ir_value.name()] = pir_value.value; + return adt::Ok{}; + } + + adt::Result InsertPackedIrValueToReplaceCtx( + const DrrPackedIrValue& res_ptn_ir_value, + RewriteCtx* rewrite_ctx, + const GraphMatchCtx& match_ctx) const { + using Ok = adt::Result; + const auto iter = + rewrite_ctx->name2packed_values.find(res_ptn_ir_value->name); + if (iter != rewrite_ctx->name2packed_values.end()) { + return adt::Ok{}; + } + const auto& opt_ir_value = SrcPtnIrValue4ResPtnIrValue(res_ptn_ir_value); + ADT_CHECK(opt_ir_value.has_value()); + const auto& ir_value = opt_ir_value.value(); + auto* vec = &rewrite_ctx->name2packed_values[ir_value.name()]; + ADT_CHECK(vec->empty()); + auto AppendNode = [&](const PirNode& pir_node) -> Ok { + 
ADT_LET_CONST_REF(pir_value, + pir_node.template TryGet()) + << adt::errors::TypeError{ + "pir_node is not an ap::paddle::NativeIrValue"}; + vec->emplace_back(pir_value.value); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + match_ctx->VisitPackedBigGraphIrValueNode(ir_value.node(), AppendNode)); + return adt::Ok{}; + } + + adt::Result CastToPirNativeIrValue( + const PirNode& pir_node) const { + using RetT = adt::Result; + return pir_node.Match( + [&](const typename PirNode::native_value_type& bir_value) -> RetT { + return bir_value; + }, + [&](const typename PirNode::ref_value_type& ref_value) -> RetT { + return ref_value.GetOwnerNativeIrValue(); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + "pir_node is not an PirNode::native_value_type or " + "PirNode::ref_value_type"}; + }); + } + + adt::Result> GetMatchedPirInputsOfRestPtnPackedIrOp( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + std::vector ret; + auto CollectInput = [&](const PirNode& pir_node) -> adt::Result { + ADT_LET_CONST_REF(pir_value, CastToPirNativeIrValue(pir_node)); + ret.emplace_back(pir_value.value); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitMatchedPirInputOfRestPtnPackedIrOp( + res_ptn_ir_op, match_ctx, CollectInput)); + return ret; + } + + template + adt::Result VisitMatchedPirInputOfRestPtnPackedIrOp( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx, + const YieldT& Yield) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.VisitMatchedBirInputOfRestPtnPackedIrOp(res_ptn_ir_op, Yield); + } + + adt::Result> GetMatchedPirOutputsOfRestPtnPackedIrOp( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx) const { + std::vector ret; + using Ok = adt::Result; + auto CollectOutput = [&](const PirNode& pir_node) -> Ok { + ADT_LET_CONST_REF(pir_value, CastToPirNativeIrValue(pir_node)); + ret.emplace_back(pir_value.value); + return adt::Ok{}; + }; + 
ADT_RETURN_IF_ERR(VisitMatchedPirOutputOfRestPtnPackedIrOp( + res_ptn_ir_op, match_ctx, CollectOutput)); + return ret; + } + + template + adt::Result VisitMatchedPirOutputOfRestPtnPackedIrOp( + const DrrPackedIrOp& res_ptn_ir_op, + const GraphMatchCtx& match_ctx, + const YieldT& Yield) const { + ap::code_gen::MatchedResultPatternHelper helper{match_ctx, + ctx_.drr_ctx}; + return helper.VisitMatchedBirOutputOfRestPtnPackedIrOp(res_ptn_ir_op, + Yield); + } + + adt::Result> GetNativeOpInputValues( + const DrrNativeIrOp& res_ptn_ir_op, const RewriteCtx& rewrite_ctx) const { + std::vector ret; + auto CollectValues = [&](pir::Value value) -> adt::Result { + ret.push_back(value); + return adt::Ok{}; + }; + auto VisitAndCollect = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + return VisitPirValueByIrValue(drr_ir_value, rewrite_ctx, CollectValues); + }; + ADT_RETURN_IF_ERR( + VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, VisitAndCollect)); + return ret; + } + + adt::Result> GetPackedOpInputValues( + const DrrPackedIrOp& res_ptn_ir_op, const RewriteCtx& rewrite_ctx) const { + std::vector ret; + auto CollectValues = [&](pir::Value value) -> adt::Result { + ret.push_back(value); + return adt::Ok{}; + }; + auto VisitAndCollect = + [&](const DrrIrValue& drr_ir_value) -> adt::Result { + return VisitPirValueByIrValue(drr_ir_value, rewrite_ctx, CollectValues); + }; + ADT_RETURN_IF_ERR( + VisitResPtnInputIrValueByResPtnIrOp(res_ptn_ir_op, VisitAndCollect)); + return ret; + } + + template + adt::Result VisitPirValueByIrValue(const DrrIrValue& ir_value, + const RewriteCtx& rewrite_ctx, + const YieldT& Yield) const { + ADT_RETURN_IF_ERR(ir_value.Match( + [&](const DrrNativeIrValue& ir_value) -> adt::Result { + const auto& name = ir_value->name; + const auto& iter = rewrite_ctx.name2native_value.find(name); + ADT_CHECK(iter != rewrite_ctx.name2native_value.end()); + return Yield(iter->second); + }, + [&](const DrrPackedIrValue& ir_value) -> adt::Result { + const auto& 
name = ir_value->name; + const auto& iter = rewrite_ctx.name2packed_values.find(name); + ADT_CHECK(iter != rewrite_ctx.name2packed_values.end()); + for (const auto& value : iter->second) { + ADT_RETURN_IF_ERR(Yield(value)); + } + return adt::Ok{}; + })); + return adt::Ok{}; + } +}; + +struct ConstraintApplier { + adt::Result Match(const DrrCtx& drr_ctx, + const GraphMatchCtx& graph_match_ctx) { + if (!drr_ctx->constraint_func.has_value()) { + return true; + } + const auto& constraint_func = drr_ctx->constraint_func.value(); + ADT_CHECK(drr_ctx->source_pattern_ctx.has_value()); + IrMatchCtx ir_match_ctx{drr_ctx->source_pattern_ctx.value(), + graph_match_ctx}; + const auto& args = GetConstraintFuncArgs(ir_match_ctx); + ApDrrHelper ap_drr_helper{drr_ctx->circlable_ref_list}; + ADT_LET_CONST_REF(is_match_val, + ap_drr_helper.Interpret(constraint_func, args)); + ADT_CHECK(drr_ctx->pass_name.has_value()); + ADT_LET_CONST_REF(is_match, is_match_val.template CastTo()) + << adt::errors::TypeError{ + std::string() + + "constraint function should return a bool (not " + + ap::axpr::GetTypeName(is_match_val) + + "). 
pass_name: " + drr_ctx->pass_name.value()}; + return is_match; + } + + std::vector GetConstraintFuncArgs( + const IrMatchCtx& ir_match_ctx) { + ap::ir_match::OpMatchCtx op_match_ctx{ir_match_ctx.shared_ptr()}; + ap::ir_match::TensorMatchCtx tensor_match_ctx{ + ir_match_ctx.shared_ptr()}; + return std::vector{ + ap::ir_match::GetOpMatchCtxClass().New( + op_match_ctx), + ap::ir_match::GetTensorMatchCtxClass().New( + tensor_match_ctx), + }; + } +}; + +struct NativeOpAnchorApLowerFusionOpPatternMatcher { + const ApLowerFusionOpPatternCtx& ctx_; + + using Self = NativeOpAnchorApLowerFusionOpPatternMatcher; + + static adt::Result> Match(const DrrCtx& drr_ctx, + pir::Operation* op) { + ADT_LET_CONST_REF(pattern_ctx, + ApLowerFusionOpPatternCtx::MakeFromDrrCtx( + drr_ctx, + /*steps_limit=*/std::nullopt, + std::make_shared(drr_ctx))); + Self matcher{pattern_ctx}; + return matcher.GetMatchCtx(op); + } + + adt::Result> GetMatchCtx( + pir::Operation* op) const { + DefaultGraph drr_graph{}; + DefaultGraph pir_graph{}; + auto* parent_block = op->GetParent(); + ADT_CHECK(parent_block != nullptr); + auto* parent_op = parent_block->GetParentOp(); + ADT_CHECK(!parent_op->isa()); + ADT_CHECK(ctx_.native_op_anchor.has_value()); + const auto& native_op_anchor = ctx_.native_op_anchor.value(); + { + ADT_LET_CONST_REF( + anchor_topo_cstr, + drr_graph.GetSmallGraphNodeTopoCstr(native_op_anchor->node)); + ap::paddle::NativeIrOp native_ir_op{op}; + ADT_LET_CONST_REF(topo_satisfy_constraint, + pir_graph.TopoSatisfy(native_ir_op, anchor_topo_cstr)); + bool satisfy_constraint = topo_satisfy_constraint; + if (satisfy_constraint) { + ap::graph::NodeDescriptor node_descriptor{}; + ADT_LET_CONST_REF(attrs_satisfy_constraint, + node_descriptor.AttrsSatisfyIfBothAreOpsOrValues( + native_ir_op, native_op_anchor->node)); + satisfy_constraint = attrs_satisfy_constraint; + if (!attrs_satisfy_constraint) { + ap::graph::NodeDescriptor drr_node_descriptor{}; + ap::graph::NodeDescriptor 
pir_node_descriptor{}; + LOG(ERROR) << "pir_node_descriptor.AttrsSatisfyIfBothAreOpsOrValues()" + " test failed. drr_node: " + << drr_node_descriptor.DebugId(native_op_anchor->node) + << ", pir_node: " + << pir_node_descriptor.DebugId(native_ir_op); + } + } else { + ap::graph::NodeDescriptor drr_node_descriptor{}; + ap::graph::NodeDescriptor pir_node_descriptor{}; + LOG(ERROR) << "pir_graph.TopoSatisfy() test failed. drr_node: " + << drr_node_descriptor.DebugId(native_op_anchor->node) + << ", pir_node: " + << pir_node_descriptor.DebugId(native_ir_op); + } + ADT_CHECK(satisfy_constraint) << adt::errors::ValueError{ + std::string() + + "pir_graph.TopoSatisfy() or " + "node_descriptor.AttrsSatisfyIfBothAreOpsOrValues() test failed. " + "drr_pass_name: " + + ctx_.drr_ctx->pass_name.value()}; + } + ADT_LET_CONST_REF(drr_op_result_anchor, + GetFirstNativeDrrIrOpResult(native_op_anchor)); + ADT_LET_CONST_REF(pir_op_result_anchor, GetFirstNativePirIrOpResult(op)); + { + ADT_LET_CONST_REF( + drr_op_result_anchor_topo_cstr, + drr_graph.GetSmallGraphNodeTopoCstr(drr_op_result_anchor)); + ADT_LET_CONST_REF(topo_satisfy_constraint, + pir_graph.TopoSatisfy(pir_op_result_anchor, + drr_op_result_anchor_topo_cstr)); + bool satisfy_constraint = topo_satisfy_constraint; + if (satisfy_constraint) { + ap::graph::NodeDescriptor node_descriptor{}; + ADT_LET_CONST_REF(attrs_satisfy_constraint, + node_descriptor.AttrsSatisfyIfBothAreOpsOrValues( + pir_op_result_anchor, drr_op_result_anchor)); + satisfy_constraint = attrs_satisfy_constraint; + } + ADT_CHECK(satisfy_constraint) << adt::errors::ValueError{ + std::string() + + "TopoSatisfy() or AttrsSatisfyIfBothAreOpsOrValues() " + "test failed. 
pir_op_result_anchor: " + + DebugId(pir_op_result_anchor) + + ", drr_op_result_anchor: " + DebugId(drr_op_result_anchor) + + ", pir_op: " + DebugId(ap::paddle::NativeIrOp{op}) + + ", drr_native_op: " + DebugId(native_op_anchor->node) + "."}; + } + std::optional opt_graph_match_ctx; + { + NativeORGraph pir_native_operand_result_graph{}; + NativeORGraph drr_native_operand_result_graph{}; + using NativeOR = ap::drr::topo_kind::NativeOperandAndResult; + ap::ir_match::GraphMatcher graph_matcher( + pir_native_operand_result_graph, drr_native_operand_result_graph); + ADT_LET_CONST_REF(graph_ctx, + graph_matcher.MatchByAnchor(pir_op_result_anchor, + drr_op_result_anchor)); + opt_graph_match_ctx = graph_ctx; + ADT_LET_CONST_REF(graph_matched, + graph_matcher.IsGraphMatched( + opt_graph_match_ctx.value(), drr_op_result_anchor)); + ADT_CHECK(graph_matched) << adt::errors::MismatchError{}; + } + ADT_CHECK(opt_graph_match_ctx.has_value()); + { + ADT_LET_CONST_REF( + ref_match_ctx, + GetRefMatchCtx(opt_graph_match_ctx.value(), drr_op_result_anchor)); + RefAugmentedGraph pir_augmented_graph{ref_match_ctx}; + using RefAugmented = ap::drr::topo_kind::RefAugmented; + using Default = ap::drr::topo_kind::Default; + ap::ir_match::GraphMatcher graph_matcher( + pir_augmented_graph, drr_graph); + ADT_RETURN_IF_ERR(graph_matcher.UpdateByConnectionsUntilDone( + &opt_graph_match_ctx.value(), drr_op_result_anchor)); + auto UpdateUntilDone = [&](auto* ctx) -> adt::Result { + ADT_RETURN_IF_ERR(graph_matcher.UpdateByConnectionsUntilDone( + ctx, drr_op_result_anchor)); + ADT_LET_CONST_REF( + graph_matched, + graph_matcher.IsGraphMatched(*ctx, drr_op_result_anchor)); + if (graph_matched) { + return adt::Break{}; + } else { + return adt::Continue{}; + } + }; + ADT_RETURN_IF_ERR(graph_matcher.InplaceForcePickOneLastUndetermined( + &opt_graph_match_ctx.value(), UpdateUntilDone)); + ADT_LET_CONST_REF(graph_matched, + graph_matcher.IsGraphMatched( + opt_graph_match_ctx.value(), drr_op_result_anchor)); 
+ if (!graph_matched) { + opt_graph_match_ctx = std::nullopt; + } + } + if (!opt_graph_match_ctx.has_value()) { + return std::nullopt; + } + ADT_LET_CONST_REF( + match, + ConstraintApplier{}.Match(ctx_.drr_ctx, opt_graph_match_ctx.value())); + if (!match) { + return std::nullopt; + } + return opt_graph_match_ctx; + } + + std::string DebugId(const PirNode& pir_node) const { + return ap::graph::NodeDescriptor{}.DebugId(pir_node); + } + + std::string DebugId(const DrrGraphNode& drr_node) const { + return ap::graph::NodeDescriptor{}.DebugId(drr_node); + } + + template + using AllOperandAndResultGraph = + ap::graph::GraphDescriptor; + + using RefNodeInfo = + ap::ir_match::RefNodeInfo; + using RefMatchCtx = + ap::ir_match::RefMatchCtx; + + adt::Result GetRefMatchCtx(const GraphMatchCtx& graph_match_ctx, + const DrrGraphNode& anchor) const { + AllOperandAndResultGraph pir_graph{}; + AllOperandAndResultGraph drr_graph{}; + using AllOR = ap::drr::topo_kind::AllOperandAndResult; + ap::ir_match::GraphMatcher graph_matcher(pir_graph, + drr_graph); + RefMatchCtx ref_match_ctx{}; + using Ok = adt::Result; + auto DoEachMismatched = [&](const DrrGraphNode& node) -> Ok { + ADT_LET_CONST_REF(drr_node, node.Get()); + return drr_node.Match( + [&](const DrrOptPackedIrOpResult& op_result) -> Ok { + ADT_LET_CONST_REF(ref_node_info, + GetRefNodeInfo(graph_match_ctx, op_result)); + if (ref_node_info.has_value()) { + ADT_RETURN_IF_ERR( + ref_match_ctx->AddRefNodeInfo(ref_node_info.value())); + } + return adt::Ok{}; + }, + [&](const DrrOptPackedIrOpOperand& impl) -> Ok { + // do nothing. + return adt::Ok{}; + }, + [&](const DrrPackedIrOpOperand& impl) -> Ok { + // do nothing. + return adt::Ok{}; + }, + [&](const DrrPackedIrOpResult& impl) -> Ok { + // do nothing. 
+ return adt::Ok{}; + }, + [&](const auto& impl) -> Ok { + const char* type_name = typeid(std::decay_t).name(); + return adt::errors::ValueError{ + std::string() + + "GetRefValue2Operands unexpected mismatched DrrGraphNode: " + + type_name}; + }); + }; + ADT_RETURN_IF_ERR(graph_matcher.VisitMisMatchedNodes( + graph_match_ctx, anchor, DoEachMismatched)); + return ref_match_ctx; + } + + adt::Result> GetRefNodeInfo( + const GraphMatchCtx& graph_match_ctx, + const DrrOptPackedIrOpResult& op_result) const { + ADT_LET_CONST_REF(opt_inner_ref_node_info, + GetInnerRefNodeInfo(graph_match_ctx, op_result)); + if (opt_inner_ref_node_info.has_value()) { + return opt_inner_ref_node_info.value(); + } + ADT_LET_CONST_REF(opt_output_ref_node_info, + GetOutputRefNodeInfo(graph_match_ctx, op_result)); + if (opt_output_ref_node_info.has_value()) { + return opt_output_ref_node_info.value(); + } + ADT_LET_CONST_REF(opt_input_ref_node_info, + GetInputRefNodeInfo(graph_match_ctx, op_result)); + if (opt_input_ref_node_info.has_value()) { + return opt_input_ref_node_info.value(); + } + return std::nullopt; + } + + adt::Result> GetInnerRefNodeInfo( + const GraphMatchCtx& graph_match_ctx, + const DrrOptPackedIrOpResult& drr_op_result) const { + DefaultGraph default_pir_graph{}; + AllOperandAndResultGraph all_o_r_drr_graph{}; + const auto& topo_match_ctx = graph_match_ctx->topo_match_ctx; + ADT_LET_CONST_REF( + drr_op_operand, + all_o_r_drr_graph.CastSoleUnignoredInput( + drr_op_result)); + { + ADT_LET_CONST_REF(num_drr_op_result_downstreams, + all_o_r_drr_graph.GetNumOutputs(drr_op_result)); + if (num_drr_op_result_downstreams == 0) { + return std::nullopt; + } + ADT_LET_CONST_REF(num_drr_op_operand_upstreams, + all_o_r_drr_graph.GetNumInputs(drr_op_operand)); + if (num_drr_op_operand_upstreams == 0) { + return std::nullopt; + } + ADT_CHECK(num_drr_op_operand_upstreams == 1); + } + ADT_LET_CONST_REF( + drr_op_operand_upstream, + all_o_r_drr_graph.CastSoleUnignoredInput( + drr_op_operand)); + 
ADT_LET_CONST_REF( + pir_op_operand_upstream, + topo_match_ctx->GetSoleBigGraphNode(drr_op_operand_upstream->node)); + ADT_LET_CONST_REF(pir_native_ir_value, + CastPirSoleOutput( + default_pir_graph, pir_op_operand_upstream)); + adt::List pir_op_operands{}; + { + auto DoEachDownstream = + [&](const DrrGraphNode& node) -> adt::Result { + ADT_LET_CONST_REF(drr_node, node.Get()); + ADT_CHECK(drr_node.template Has()); + ADT_LET_CONST_REF(pir_node, topo_match_ctx->GetSoleBigGraphNode(node)); + ADT_LET_CONST_REF(pir_native_ir_op_operand, + pir_node.TryGet()); + ADT_LET_CONST_REF(cur_pir_native_ir_value, + CastPirSoleInput( + default_pir_graph, pir_native_ir_op_operand)); + if (cur_pir_native_ir_value == pir_native_ir_value) { + pir_op_operands->push_back(pir_native_ir_op_operand); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(all_o_r_drr_graph.VisitDownstreamNodes( + drr_op_result->node, DoEachDownstream)); + if (pir_op_operands->empty()) { + return std::nullopt; + } + } + return RefNodeInfo{pir_native_ir_value, pir_op_operands}; + } + + adt::Result> GetOutputRefNodeInfo( + const GraphMatchCtx& graph_match_ctx, + const DrrOptPackedIrOpResult& drr_op_result) const { + DefaultGraph default_drr_graph{}; + DefaultGraph default_pir_graph{}; + AllOperandAndResultGraph all_o_r_drr_graph{}; + const auto& topo_match_ctx = graph_match_ctx->topo_match_ctx; + ADT_LET_CONST_REF( + drr_op_operand, + all_o_r_drr_graph.CastSoleUnignoredInput( + drr_op_result)); + { + ADT_LET_CONST_REF(num_drr_op_result_downstreams, + all_o_r_drr_graph.GetNumOutputs(drr_op_result)); + if (num_drr_op_result_downstreams != 0) { + return std::nullopt; + } + ADT_LET_CONST_REF(num_drr_op_operand_upstreams, + all_o_r_drr_graph.GetNumInputs(drr_op_operand)); + if (num_drr_op_operand_upstreams == 0) { + return std::nullopt; + } + ADT_CHECK(num_drr_op_operand_upstreams == 1); + } + ADT_LET_CONST_REF( + drr_op_operand_upstream, + all_o_r_drr_graph.CastSoleUnignoredInput( + drr_op_operand)); + ADT_LET_CONST_REF( 
+ pir_op_operand_upstream, + topo_match_ctx->GetSoleBigGraphNode(drr_op_operand_upstream->node)); + ADT_LET_CONST_REF(pir_native_ir_value, + CastPirSoleOutput( + default_pir_graph, pir_op_operand_upstream)); + ADT_LET_CONST_REF( + drr_ir_value, + default_drr_graph.CastSoleUnignoredInput( + drr_op_operand)); + std::unordered_set excluded; + { + auto DoEachDownstream = + [&](const DrrGraphNode& node) -> adt::Result { + ADT_LET_CONST_REF(drr_node, node.Get()); + if (!drr_node.template Has()) { + return adt::Ok{}; + } + ADT_LET_CONST_REF(pir_node, topo_match_ctx->GetSoleBigGraphNode(node)); + ADT_LET_CONST_REF(pir_op_operand, + pir_node.TryGet()); + ADT_CHECK(excluded.emplace(pir_op_operand).second); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(default_drr_graph.VisitDownstreamNodes( + drr_ir_value->node, DoEachDownstream)); + } + adt::List pir_op_operands{}; + { + auto DoEachDownstream = [&](const PirNode& node) -> adt::Result { + if (!node.template Has()) { + return adt::Ok{}; + } + ADT_LET_CONST_REF(pir_op_operand, node.TryGet()); + if (excluded.count(pir_op_operand) == 0) { + pir_op_operands->push_back(pir_op_operand); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(default_pir_graph.VisitDownstreamNodes( + pir_native_ir_value, DoEachDownstream)); + } + return RefNodeInfo{pir_native_ir_value, pir_op_operands}; + } + + adt::Result> GetInputRefNodeInfo( + const GraphMatchCtx& graph_match_ctx, + const DrrOptPackedIrOpResult& drr_op_result) const { + DefaultGraph default_pir_graph{}; + AllOperandAndResultGraph all_o_r_drr_graph{}; + const auto& topo_match_ctx = graph_match_ctx->topo_match_ctx; + ADT_LET_CONST_REF( + drr_op_operand, + all_o_r_drr_graph.CastSoleUnignoredInput( + drr_op_result)); + { + ADT_LET_CONST_REF(num_drr_op_result_downstreams, + all_o_r_drr_graph.GetNumOutputs(drr_op_result)); + if (num_drr_op_result_downstreams == 0) { + return std::nullopt; + } + ADT_LET_CONST_REF(num_drr_op_operand_upstreams, + all_o_r_drr_graph.GetNumInputs(drr_op_operand)); + 
if (num_drr_op_operand_upstreams != 0) { + return std::nullopt; + } + } + std::optional pir_native_ir_value; + adt::List pir_op_operands{}; + { + auto DoEachDownstream = + [&](const DrrGraphNode& node) -> adt::Result { + ADT_LET_CONST_REF(drr_node, node.Get()); + ADT_CHECK(drr_node.template Has()); + ADT_LET_CONST_REF(pir_node, topo_match_ctx->GetSoleBigGraphNode(node)); + ADT_LET_CONST_REF(pir_native_ir_op_operand, + pir_node.TryGet()); + ADT_LET_CONST_REF(cur_pir_native_ir_value, + CastPirSoleInput( + default_pir_graph, pir_native_ir_op_operand)); + if (!pir_native_ir_value.has_value()) { + ADT_LET_CONST_REF( + cur_pir_native_ir_value_upstream, + GetPirSoleInput(default_pir_graph, cur_pir_native_ir_value)); + if (!cur_pir_native_ir_value_upstream + .template Has()) { + return adt::Ok{}; + } + pir_native_ir_value = cur_pir_native_ir_value; + } + ADT_CHECK(cur_pir_native_ir_value == pir_native_ir_value.value()); + pir_op_operands->push_back(pir_native_ir_op_operand); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(all_o_r_drr_graph.VisitDownstreamNodes( + drr_op_result->node, DoEachDownstream)); + } + if (!pir_native_ir_value.has_value()) { + return std::nullopt; + } + if (pir_op_operands->empty()) { + return std::nullopt; + } + return RefNodeInfo{pir_native_ir_value.value(), pir_op_operands}; + } + + template + adt::Result CastPirSoleOutput(const GraphT& pir_graph, + const PirNode& node) const { + std::optional opt_pir_node{}; + auto DoEachDownstream = + [&](const PirNode& downstream) -> adt::Result { + ADT_LET_CONST_REF(pir_node_impl, downstream.TryGet()); + ADT_CHECK(!opt_pir_node.has_value()); + opt_pir_node = pir_node_impl; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(pir_graph.VisitDownstreamNodes(node, DoEachDownstream)); + ADT_CHECK(opt_pir_node.has_value()); + return opt_pir_node.value(); + } + + template + adt::Result CastPirSoleInput(const GraphT& pir_graph, + const PirNode& node) const { + std::optional opt_pir_node{}; + auto DoEachUpstream = [&](const 
PirNode& upstream) -> adt::Result { + ADT_LET_CONST_REF(pir_node_impl, upstream.TryGet()); + ADT_CHECK(!opt_pir_node.has_value()); + opt_pir_node = pir_node_impl; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(pir_graph.VisitUpstreamNodes(node, DoEachUpstream)); + ADT_CHECK(opt_pir_node.has_value()); + return opt_pir_node.value(); + } + + template + adt::Result GetPirSoleInput(const GraphT& pir_graph, + const PirNode& node) const { + std::optional opt_pir_node{}; + auto DoEachUpstream = [&](const PirNode& upstream) -> adt::Result { + ADT_CHECK(!opt_pir_node.has_value()); + opt_pir_node = upstream; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(pir_graph.VisitUpstreamNodes(node, DoEachUpstream)); + ADT_CHECK(opt_pir_node.has_value()); + return opt_pir_node.value(); + } + + template + adt::Result GetNumPirOutputs(const GraphT& pir_graph, + const PirNode& node) const { + std::size_t num_outputs = 0; + auto DoEachDownstream = + [&](const PirNode& downstream) -> adt::Result { + ++num_outputs; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(pir_graph.VisitDownstreamNodes(node, DoEachDownstream)); + return num_outputs; + } + + template + adt::Result GetNumPirInputs(const GraphT& pir_graph, + const PirNode& node) const { + std::size_t num_inputs = 0; + auto DoEachUpstream = [&](const PirNode& upstream) -> adt::Result { + ++num_inputs; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(pir_graph.VisitUpstreamNodes(node, DoEachUpstream)); + return num_inputs; + } + + adt::Result GetFirstNativeDrrIrOpResult( + const DrrNativeIrOp& op) const { + ADT_LET_CONST_REF(downstreams, op->node.DownstreamNodes()); + ADT_CHECK(downstreams.size() > 0); + using List = adt::List; + using Vec = ap::graph::IndexedTag; + ADT_LET_CONST_REF(indexed_list, downstreams.template TryGet()); + return indexed_list.data->at(0); + } + + adt::Result GetFirstNativePirIrOpResult(pir::Operation* op) const { + ADT_CHECK(!op->isa()); + ADT_CHECK(op->num_results() > 0); + pir::Value value = op->result(0); + 
ap::paddle::NativeIrOpResult ir_op_result{ + pir::OpResult::dyn_cast_from(value)}; + return ir_op_result; + } +}; + +struct OpEraseHelepr { + ap::drr::SourcePatternCtx source_pattern_ctx_; + + adt::Result EraseUnusedOps(pir::PatternRewriter* rewriter, + const GraphMatchCtx& graph_match_ctx) { + using Ok = adt::Result; + auto TryErase = [&](const DrrGraphNode& drr_graph_node) -> Ok { + ADT_LET_CONST_REF(drr_node, drr_graph_node.Get()); + return EraseIfUnused(drr_node, rewriter, graph_match_ctx); + }; + ADT_RETURN_IF_ERR(ReversedVisitSrcPtnGraph(TryErase)); + return adt::Ok{}; + } + + private: + adt::Result EraseIfUnused(const DrrNode& drr_node, + pir::PatternRewriter* rewriter, + const GraphMatchCtx& graph_match_ctx) { + using Ok = adt::Result; + return drr_node.Match( + [&](const DrrNativeIrOp&) -> Ok { + return EraseOpIfUnused(drr_node, rewriter, graph_match_ctx); + }, + [&](const DrrPackedIrOp&) -> Ok { + return EraseOpIfUnused(drr_node, rewriter, graph_match_ctx); + }, + [&](const DrrOptPackedIrOp&) -> Ok { + return EraseOpIfUnused(drr_node, rewriter, graph_match_ctx); + }, + [&](const auto& impl) -> Ok { + // Do nothing. + return adt::Ok{}; + }); + } + + adt::Result EraseOpIfUnused(const DrrNode& drr_node, + pir::PatternRewriter* rewriter, + const GraphMatchCtx& graph_match_ctx) { + ADT_LET_CONST_REF(pir_node, + graph_match_ctx->GetSoleBigGraphNode(drr_node.node())); + using Ok = adt::Result; + return pir_node.Match( + [&](const ap::paddle::NativeIrOp& ir_op) -> Ok { + return ErasePirOpIfUnused(ir_op.op, rewriter); + }, + [&](const ap::paddle::PackedIrOp& ir_op) -> Ok { + return ErasePirOpIfUnused(ir_op.fusion_op, rewriter); + }, + [&](const auto&) -> Ok { + // Do nothing. 
+ return adt::Ok{}; + }); + } + + adt::Result ErasePirOpIfUnused(const pir::Operation* op, + pir::PatternRewriter* rewriter) { + auto* mut_op = const_cast(op); + if (mut_op->use_empty()) { + rewriter->EraseOp(mut_op); + } + return adt::Ok{}; + } + + template + adt::Result ReversedVisitSrcPtnGraph(const YieldT& Yield) { + std::list sinks; + for (const auto& drr_node : source_pattern_ctx_->node_arena->nodes()) { + const auto& drr_graph_node = drr_node.node(); + ADT_LET_CONST_REF(downstreams, drr_graph_node.DownstreamNodes()); + if (downstreams.size() == 0) { + sinks.push_back(drr_graph_node); + } + } + using Ok = adt::Result; + ap::drr::DefaultDrrGraphDescriptor graph{}; + auto VisitPrev = [&](const DrrGraphNode& node, const auto& Yield) -> Ok { + return graph.VisitDownstreamNodes(node, Yield); + }; + auto VisitNext = [&](const DrrGraphNode& node, const auto& Yield) -> Ok { + return graph.VisitUpstreamNodes(node, Yield); + }; + ap::adt::TopoWalker walker{VisitPrev, VisitNext}; + ADT_RETURN_IF_ERR(walker(sinks.begin(), sinks.end(), Yield)); + return adt::Ok{}; + } +}; + +class NativeOpAnchorApLowerFusionOpPattern : public pir::RewritePattern { + private: + ApLowerFusionOpPatternCtx ctx_; + ApRewriter ap_rewriter_; + mutable std::size_t times_; + + public: + NativeOpAnchorApLowerFusionOpPattern(pir::IrContext* ir_context, + const ApLowerFusionOpPatternCtx& ctx) + : pir::RewritePattern(ctx.anchor_op_name, 1, ir_context, {}), + ctx_(ctx), + times_(0), + ap_rewriter_(ctx, &NativeOpAnchorApLowerFusionOpPatternMatcher::Match) { + } + + bool MatchAndRewrite( + pir::Operation* op, + pir::PatternRewriter& rewriter) const override { // // NOLINT + if (ctx_.steps_limit.has_value()) { + if (times_ >= ctx_.steps_limit.value()) { + return false; + } + } + const auto& ret = this->TryMatchAndRewrite(op, &rewriter); + if (ret.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" + << ret.GetError().class_name() << ": " 
<< ret.GetError().msg() + << "\npass_name: " << ctx_.drr_ctx->pass_name.value(); + return false; + } + bool success = ret.GetOkValue(); + if (success) { + ++times_; + } + return success; + } + + adt::Result TryMatchAndRewrite(pir::Operation* op, + pir::PatternRewriter* rewriter) const { + ADT_LET_CONST_REF(opt_match_ctx, GetMatchCtx(op)); + if (!opt_match_ctx.has_value()) { + return false; + } + ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); + LOG(ERROR) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; + ADT_LET_CONST_REF( + success, ap_rewriter_.Rewrite(opt_match_ctx.value(), op, rewriter)); + if (success) { + ADT_CHECK(ctx_.drr_ctx->source_pattern_ctx.has_value()); + OpEraseHelepr erase_helper{ctx_.drr_ctx->source_pattern_ctx.value()}; + ADT_RETURN_IF_ERR( + erase_helper.EraseUnusedOps(rewriter, opt_match_ctx.value())); + } + return success; + } + + adt::Result> GetMatchCtx( + pir::Operation* op) const { + return NativeOpAnchorApLowerFusionOpPatternMatcher{ctx_}.GetMatchCtx(op); + } +}; + +struct DefaultAnchorApLowerFusionOpPatternMatcher { + const ApLowerFusionOpPatternCtx& ctx_; + + using Self = DefaultAnchorApLowerFusionOpPatternMatcher; + + static adt::Result> Match(const DrrCtx& drr_ctx, + pir::Operation* op) { + ADT_LET_CONST_REF(pattern_ctx, + ApLowerFusionOpPatternCtx::MakeFromDrrCtx( + drr_ctx, + /*times_step=*/std::nullopt, + std::make_shared(drr_ctx))); + Self matcher{pattern_ctx}; + return matcher.GetMatchCtx(op); + } + + adt::Result> GetMatchCtx( + pir::Operation* op) const { + auto* parent_block = op->GetParent(); + ADT_CHECK(parent_block != nullptr); + auto* parent_op = parent_block->GetParentOp(); + ADT_CHECK(!parent_op->isa()); + const auto& default_anchor = ctx_.default_anchor; + using Default = ap::drr::topo_kind::Default; + ap::graph::GraphDescriptor pir_graph{}; + ap::graph::GraphDescriptor src_ptn_graph{}; + ap::ir_match::GraphMatcher graph_matcher( + pir_graph, src_ptn_graph); + ADT_LET_CONST_REF( + anchor_topo_cstr, + 
src_ptn_graph.GetSmallGraphNodeTopoCstr(default_anchor.node())); + const auto& obj_node = CastToPirNode(op); + ADT_LET_CONST_REF(topo_satisfy_constraint, + pir_graph.TopoSatisfy(obj_node, anchor_topo_cstr)); + bool satisfy_constraint = topo_satisfy_constraint; + if (satisfy_constraint) { + ap::graph::NodeDescriptor node_descriptor{}; + ADT_LET_CONST_REF(attrs_satisfy_constraint, + node_descriptor.AttrsSatisfyIfBothAreOpsOrValues( + obj_node, default_anchor.node())); + satisfy_constraint = attrs_satisfy_constraint; + } + ADT_CHECK(satisfy_constraint) << adt::errors::ValueError{ + "TopoSatisfy() or AttrsSatisfyIfBothAreOpsOrValues() test failed."}; + ADT_LET_CONST_REF( + graph_match_ctx, + graph_matcher.MatchByAnchor(obj_node, default_anchor.node())); + ADT_LET_CONST_REF( + graph_matched, + graph_matcher.IsGraphMatched(graph_match_ctx, default_anchor.node())); + if (!graph_matched) { + return std::nullopt; + } + ADT_LET_CONST_REF(constraint_matched, + ConstraintApplier{}.Match(ctx_.drr_ctx, graph_match_ctx)); + if (!constraint_matched) { + return std::nullopt; + } + return graph_match_ctx; + } + + PirNode CastToPirNode(pir::Operation* op) const { + if (op->isa()) { + ap::paddle::PackedIrOp ir_op{op->dyn_cast()}; + return ir_op; + } else { + ap::paddle::NativeIrOp ir_op{op}; + return ir_op; + } + } +}; + +class DefaultAnchorApLowerFusionOpPattern : public pir::RewritePattern { + private: + ApLowerFusionOpPatternCtx ctx_; + mutable std::size_t times_; + ApRewriter ap_rewriter_; + + public: + DefaultAnchorApLowerFusionOpPattern(pir::IrContext* ir_context, + const ApLowerFusionOpPatternCtx& ctx) + : pir::RewritePattern(ctx.anchor_op_name, 1, ir_context, {}), + ctx_(ctx), + times_(0), + ap_rewriter_(ctx, &DefaultAnchorApLowerFusionOpPatternMatcher::Match) {} + + bool MatchAndRewrite( + pir::Operation* op, + pir::PatternRewriter& rewriter) const override { // // NOLINT + if (ctx_.steps_limit.has_value()) { + if (times_ >= ctx_.steps_limit.value()) { + return false; + } + } 
+ const auto& ret = this->TryMatchAndRewrite(op, &rewriter); + if (ret.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" + << ret.GetError().class_name() << ": " << ret.GetError().msg() + << "\npass_name: " << ctx_.drr_ctx->pass_name.value(); + return false; + } + bool success = ret.GetOkValue(); + if (success) { + ++times_; + } + return success; + } + + adt::Result TryMatchAndRewrite(pir::Operation* op, + pir::PatternRewriter* rewriter) const { + ADT_LET_CONST_REF(opt_match_ctx, GetMatchCtx(op)); + if (!opt_match_ctx.has_value()) { + return false; + } + ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); + LOG(ERROR) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; + ADT_LET_CONST_REF( + success, ap_rewriter_.Rewrite(opt_match_ctx.value(), op, rewriter)); + if (success) { + ADT_CHECK(ctx_.drr_ctx->source_pattern_ctx.has_value()); + OpEraseHelepr erase_helper{ctx_.drr_ctx->source_pattern_ctx.value()}; + ADT_RETURN_IF_ERR( + erase_helper.EraseUnusedOps(rewriter, opt_match_ctx.value())); + } + return success; + } + + adt::Result> GetMatchCtx( + pir::Operation* op) const { + return DefaultAnchorApLowerFusionOpPatternMatcher{ctx_}.GetMatchCtx(op); + } +}; + +class ApLowerFusionOpPass : public pir::PatternRewritePass { + private: + std::shared_ptr drr_ctx_provider_; + std::optional steps_limit_; + + public: + explicit ApLowerFusionOpPass( + const std::shared_ptr& drr_ctx_provider, + const std::string& name, + std::optional steps_limit) + : pir::PatternRewritePass( + std::string() + "ap_lower_fusion_op_" + name + "_pass", 2), + drr_ctx_provider_(drr_ctx_provider), + steps_limit_(steps_limit) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + const auto& ret = TryInitializePatterns(&ps, context); + if (ret.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" 
+ << "InitializePatterns " << ret.GetError().class_name() << ": " + << ret.GetError().msg(); + } + return ps; + } + + adt::Result TryInitializePatterns(pir::RewritePatternSet* ps, + pir::IrContext* context) { + auto AddFusionOpPattern = [&](const auto& drr_ctx) -> adt::Result { + ADT_LET_CONST_REF(pattern_ctx, + ApLowerFusionOpPatternCtx::MakeFromDrrCtx( + drr_ctx, steps_limit_, drr_ctx_provider_)); + if (pattern_ctx.native_op_anchor.has_value()) { + ps->Add(std::make_unique( + context, pattern_ctx)); + } else { + ps->Add(std::make_unique( + context, pattern_ctx)); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachDrrCtx(AddFusionOpPattern)); + return adt::Ok{}; + } + + template + adt::Result VisitEachDrrCtx(const YieldT& Yield) { + ADT_LET_CONST_REF(drr_ctx_list, drr_ctx_provider_->GetDrrCtxList()); + for (const auto& drr_ctx : *drr_ctx_list) { + ADT_RETURN_IF_ERR(Yield(drr_ctx)); + } + return adt::Ok{}; + } +}; + +class AbstractDrrCtxProvider : public DrrCtxProvider { + std::weak_ptr circlable_ref_list_; + + public: + explicit AbstractDrrCtxProvider( + const std::weak_ptr& circlable_ref_list) + : circlable_ref_list_(circlable_ref_list) {} + + adt::Result> GetDrrCtxList() override { + static adt::Result> drr_ctx_list(MakeDrrCtxList()); + return drr_ctx_list; + } + + adt::Result> MakeDrrCtxList() { + adt::List ret{}; + auto Collect = [&](const auto& drr_ctx) -> adt::Result { + ret->emplace_back(drr_ctx); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachDrrCtxByAbstractDrrPassRegistryItems(Collect)); + return ret; + } + + adt::Result PostProcess( + adt::Result> (*Match)(const DrrCtx&, + pir::Operation* op), + const DrrCtx& drr_ctx, + pir::Operation* op, + const GraphMatchCtx& match_ctx, + const std::function(const std::string&)>& + CodeGenResult4FusedOpName) override { + ap::reified_drr::ReifiedDrrPassDumpHelper dump_helper{}; + if (!dump_helper.DumpEnabled()) { + return adt::Ok{}; + } + ap::paddle::PirToAnfExprHelper attr2axpr_helper{}; + 
ADT_CHECK(drr_ctx->source_pattern_ctx.has_value()); + const auto& src_ptn_ctx = drr_ctx->source_pattern_ctx.value(); + ap::paddle::PirNodeMatchedSrcPtnCtxHelper src_ptn_ctx_helper(src_ptn_ctx, + match_ctx); + ADT_LET_CONST_REF( + reified_drr_pass_class_lambda_anf_expr, + dump_helper.Dump( + /*abstract_drr_ctx=*/drr_ctx, + /*attr2axpr_helper=*/&attr2axpr_helper, + /*src_ptn_ctx_helper=*/&src_ptn_ctx_helper, + /*CodeGenResult4FusedOpName=*/CodeGenResult4FusedOpName, + /*nice=*/0)); + ADT_LET_CONST_REF(reified_drr_ctx, + GetReifiedDrrCtx(drr_ctx->circlable_ref_list, + reified_drr_pass_class_lambda_anf_expr)); + ADT_LET_CONST_REF(opt_match_ctx, Match(reified_drr_ctx, op)); + ADT_CHECK(opt_match_ctx.has_value()); + return adt::Ok{}; + } + + private: + static adt::Result GetReifiedDrrCtx( + const std::weak_ptr& circlable_ref_list, + const ap::axpr::AnfExpr& reified_drr_pass_class_lambda_anf_expr) { + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr( + reified_drr_pass_class_lambda_anf_expr); + using CoreExpr = ap::axpr::CoreExpr; + ADT_LET_CONST_REF(atomic, + core_expr.template TryGet>()); + ADT_LET_CONST_REF(lambda, + atomic.template TryGet>()); + const auto& frames = ap::axpr::MakeBuiltinFrameAttrMap(); + ap::axpr::CpsInterpreter interpreter{frames, circlable_ref_list}; + ADT_LET_CONST_REF(drr_pass_class_val, interpreter.Interpret(lambda, {})); + ADT_LET_CONST_REF( + drr_pass_class, + drr_pass_class_val.template CastTo< + ap::axpr::TypeImpl>>()); + return ApDrrHelper{circlable_ref_list}.Interpret( + drr_pass_class.class_attrs); + } + + template + adt::Result VisitEachDrrCtxByAbstractDrrPassRegistryItems( + const YieldT& Yield) { + ADT_LET_CONST_REF(registry, ApRegistryHelper{}.SingletonRegistry()); + const auto& abstract_drr_pass_registry_items = + registry->abstract_drr_pass_registry_items; + for (const auto& [abstract_drr_pass_name, nice2abstract_drr_pass_items] : + abstract_drr_pass_registry_items) { + std::optional opt_drr_ctx; + for (const auto& [nice, 
abstract_drr_pass_items] : + nice2abstract_drr_pass_items) { + if (opt_drr_ctx.has_value()) { + break; + } + for (const auto& abstract_drr_pass_item : abstract_drr_pass_items) { + const auto& drr_ctx = GetDrrCtx(abstract_drr_pass_item); + if (drr_ctx.HasOkValue()) { + ADT_RETURN_IF_ERR(Yield(drr_ctx.GetOkValue())); + opt_drr_ctx = drr_ctx.GetOkValue(); + break; + } else { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx.GetError().CallStackToString() << "\n" + << drr_ctx.GetError().class_name() + << ": abstract_drr_pass_name: " << abstract_drr_pass_name + << " nice: " << nice + << " msg: " << drr_ctx.GetError().msg(); + } + } + } + } + return adt::Ok{}; + } + + adt::Result GetDrrCtx( + const ap::registry::AbstractDrrPassRegistryItem& abstract_drr_pass_item) { + ADT_LET_CONST_REF(drr_ctx, + ApDrrHelper{circlable_ref_list_}.Interpret( + abstract_drr_pass_item->cls)); + if (!drr_ctx->pass_name.has_value()) { + drr_ctx.shared_ptr()->pass_name = + abstract_drr_pass_item->abstract_drr_pass_name; + } + return drr_ctx; + } +}; + +class ClassicDrrCtxProvider : public DrrCtxProvider { + std::weak_ptr circlable_ref_list_; + + public: + explicit ClassicDrrCtxProvider( + const std::weak_ptr& circlable_ref_list) + : circlable_ref_list_(circlable_ref_list) {} + + adt::Result> GetDrrCtxList() override { + static adt::Result> drr_ctx_list(MakeDrrCtxList()); + return drr_ctx_list; + } + + adt::Result> MakeDrrCtxList() { + adt::List ret{}; + auto Collect = [&](const auto& drr_ctx) -> adt::Result { + ret->emplace_back(drr_ctx); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachDrrCtxByClassicDrrPassRegistryItems(Collect)); + return ret; + } + + adt::Result PostProcess( + adt::Result> (*Match)(const DrrCtx&, + pir::Operation* op), + const DrrCtx& drr_ctx, + pir::Operation* op, + const GraphMatchCtx& match_ctx, + const std::function(const std::string&)>& + CodeGenResult4FusedOpName) override { + // Do nothing. 
+ return adt::Ok{}; + } + + private: + template + adt::Result VisitEachDrrCtxByClassicDrrPassRegistryItems( + const YieldT& Yield) { + ADT_LET_CONST_REF(registry, ApRegistryHelper{}.SingletonRegistry()); + const auto& classic_drr_pass_registry_items = + registry->classic_drr_pass_registry_items; + for (const auto& [classic_drr_pass_name, nice2classic_drr_pass_items] : + classic_drr_pass_registry_items) { + std::optional opt_drr_ctx; + for (const auto& [nice, classic_drr_pass_items] : + nice2classic_drr_pass_items) { + if (opt_drr_ctx.has_value()) { + break; + } + for (const auto& classic_drr_pass_item : classic_drr_pass_items) { + const auto& drr_ctx = GetDrrCtx(classic_drr_pass_item); + if (drr_ctx.HasOkValue()) { + ADT_RETURN_IF_ERR(Yield(drr_ctx.GetOkValue())); + opt_drr_ctx = drr_ctx.GetOkValue(); + break; + } else { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx.GetError().CallStackToString() << "\n" + << drr_ctx.GetError().class_name() + << ": classic_drr_pass_name: " << classic_drr_pass_name + << " nice: " << nice + << " msg: " << drr_ctx.GetError().msg(); + } + } + } + } + return adt::Ok{}; + } + + adt::Result GetDrrCtx( + const ap::registry::ClassicDrrPassRegistryItem& classic_drr_pass_item) { + ADT_LET_CONST_REF( + drr_ctx, + ApDrrHelper{circlable_ref_list_}.Interpret(classic_drr_pass_item->cls)); + if (!drr_ctx->pass_name.has_value()) { + drr_ctx.shared_ptr()->pass_name = + classic_drr_pass_item->classic_drr_pass_name; + } + return drr_ctx; + } +}; + +class TagAccessTopoDrrCtxProvider : public DrrCtxProvider { + std::weak_ptr circlable_ref_list_; + std::string pass_tag_name_; + + public: + explicit TagAccessTopoDrrCtxProvider( + const std::weak_ptr& circlable_ref_list, + const std::string& pass_tag_name) + : circlable_ref_list_(circlable_ref_list), + pass_tag_name_(pass_tag_name) {} + + adt::Result> GetDrrCtxList() override { + adt::Result> drr_ctx_list(MakeDrrCtxList(pass_tag_name_)); + return drr_ctx_list; + } + + adt::Result> 
MakeDrrCtxList( + const std::string& pass_tag_name) { + adt::List ret{}; + auto Collect = [&](const auto& drr_ctx) -> adt::Result { + ret->emplace_back(drr_ctx); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachDrrCtxByAccessTopoDrrPassRegistryItems( + pass_tag_name, Collect)); + return ret; + } + + adt::Result PostProcess( + adt::Result> (*Match)(const DrrCtx&, + pir::Operation* op), + const DrrCtx& drr_ctx, + pir::Operation* op, + const GraphMatchCtx& match_ctx, + const std::function(const std::string&)>& + CodeGenResult4FusedOpName) override { + // Do nothing. + return adt::Ok{}; + } + + private: + template + adt::Result VisitEachDrrCtxByAccessTopoDrrPassRegistryItems( + const std::string& pass_tag_name, const YieldT& Yield) { + ADT_LET_CONST_REF(registry, ApRegistryHelper{}.SingletonRegistry()); + const auto& access_topo_drr_pass_registry_items = + registry->access_topo_drr_pass_registry_items; + for (const auto& [access_topo_drr_pass_name, + nice2access_topo_drr_pass_items] : + access_topo_drr_pass_registry_items) { + std::optional opt_drr_ctx; + for (const auto& [nice, access_topo_drr_pass_items] : + nice2access_topo_drr_pass_items) { + if (opt_drr_ctx.has_value()) { + break; + } + for (const auto& access_topo_drr_pass_item : + access_topo_drr_pass_items) { + if (pass_tag_name != access_topo_drr_pass_item->pass_tag_name) { + continue; + } + const auto& drr_ctx = GetDrrCtx(access_topo_drr_pass_item); + if (drr_ctx.HasOkValue()) { + ADT_RETURN_IF_ERR(Yield(drr_ctx.GetOkValue())); + opt_drr_ctx = drr_ctx.GetOkValue(); + break; + } else { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx.GetError().CallStackToString() << "\n" + << drr_ctx.GetError().class_name() + << ": access_topo_drr_pass_name: " + << access_topo_drr_pass_name << " nice: " << nice + << " msg: " << drr_ctx.GetError().msg(); + } + } + } + } + return adt::Ok{}; + } + + adt::Result GetDrrCtx( + const ap::registry::AccessTopoDrrPassRegistryItem& + access_topo_drr_pass_item) 
{ + ADT_LET_CONST_REF(drr_ctx, + ApDrrHelper{circlable_ref_list_}.Interpret( + access_topo_drr_pass_item->cls)); + if (!drr_ctx->pass_name.has_value()) { + drr_ctx.shared_ptr()->pass_name = + access_topo_drr_pass_item->access_topo_drr_pass_name; + } + return drr_ctx; + } +}; + +class CustomAccessTopoDrrCtxProvider : public DrrCtxProvider { + std::weak_ptr circlable_ref_list_; + ap::axpr::Value drr_pass_obj_; + ap::axpr::Value mut_matched_pattern_as_programs_; + std::size_t seq_no_; + + public: + explicit CustomAccessTopoDrrCtxProvider( + const std::weak_ptr& circlable_ref_list, + const ap::axpr::Value& drr_pass_obj, + const ap::axpr::Value& mut_matched_pattern_as_programs) + : circlable_ref_list_(circlable_ref_list), + drr_pass_obj_(drr_pass_obj), + mut_matched_pattern_as_programs_(mut_matched_pattern_as_programs), + seq_no_(0) {} + + adt::Result> GetDrrCtxList() override { + adt::Result> drr_ctx_list(MakeDrrCtxList()); + return drr_ctx_list; + } + + adt::Result> MakeDrrCtxList() { + adt::List ret{}; + auto Collect = [&](const auto& drr_ctx) -> adt::Result { + ret->emplace_back(drr_ctx); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachCreatedDrrCtx(Collect)); + return ret; + } + + adt::Result PostProcess( + adt::Result> (*Match)(const DrrCtx&, + pir::Operation* op), + const DrrCtx& drr_ctx, + pir::Operation* op, + const GraphMatchCtx& match_ctx, + const std::function(const std::string&)>& + CodeGenResult4FusedOpName) override { + if (mut_matched_pattern_as_programs_.template CastableTo()) { + return adt::Ok{}; + } + ADT_LET_CONST_REF( + mut_lst, + mut_matched_pattern_as_programs_ + .template CastTo>()); + ADT_CHECK(drr_ctx->source_pattern_ctx.has_value()); + ADT_LET_CONST_REF(program, + CopyMatchedPatternToProgram( + drr_ctx->source_pattern_ctx.value(), match_ctx)); + ADT_LET_CONST_REF(mut_list_ptr, mut_lst.Mut()); + mut_list_ptr->emplace_back(ap::paddle::GetPirProgramClass().New(program)); + return adt::Ok{}; + } + + private: + adt::Result 
CopyMatchedPatternToProgram( + const ap::drr::SourcePatternCtx& source_pattern_ctx, + const GraphMatchCtx& match_ctx) { + pir::IrContext* ctx = ::pir::IrContext::Instance(); + auto new_program = std::make_shared<::pir::Program>(ctx); + auto clone_options = ::pir::CloneOptions::All(); + pir::IrMapping ir_mapping; + pir::Builder builder(ctx, new_program->block()); + auto DoEachInput = [&](const auto& name, + pir::Value value) -> adt::Result { + ADT_CHECK(value.type().isa()); + const auto& type = value.type().dyn_cast(); + const auto& dims = ::common::vectorize(type.dims()); + auto phi_type = ::paddle::dialect::TransToPhiDataType(type.dtype()); + auto op = builder.Build<::paddle::dialect::DataOp>( + name, dims, phi_type, phi::Place()); + ir_mapping.Add(value, op->result(0)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitMatchedInput(source_pattern_ctx, match_ctx, DoEachInput)); + std::optional old_program{}; + auto DoEachOp = [&](pir::Operation* op) -> adt::Result { + if (old_program.has_value()) { + ADT_CHECK(old_program.value() == op->GetParentProgram()); + } else { + old_program = op->GetParentProgram(); + } + auto* new_op = op->Clone(ir_mapping, clone_options); + new_program->block()->push_back(new_op); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitMatchedOp(source_pattern_ctx, match_ctx, DoEachOp)); + if (old_program.has_value()) { + ADT_RETURN_IF_ERR(CloneSymbolicShapes( + new_program.get(), old_program.value(), ir_mapping)); + } + return ap::paddle::Program{new_program}; + } + + adt::Result CloneSymbolicShapes(pir::Program* new_program, + pir::Program* old_program, + const pir::IrMapping& ir_mapping) { + auto* new_shape_analysis = + &::pir::ShapeAnalysisManager::Instance().Get(new_program); + auto* old_shape_analysis = + &::pir::ShapeAnalysisManager::Instance().Get(old_program); + for (const auto& [old_value, new_value] : ir_mapping.GetMap()) { + new_shape_analysis->SetShapeOrDataForValue( + new_value, 
old_shape_analysis->GetShapeOrDataForValue(old_value)); + } + return adt::Ok{}; + } + + template + adt::Result VisitMatchedOp( + const ap::drr::SourcePatternCtx& source_pattern_ctx, + const GraphMatchCtx& match_ctx, + const YieldT& Yield) { + using Ok = adt::Result; + auto DoEachOp = [&](const auto& drr_graph_node) -> Ok { + ADT_LET_CONST_REF(drr_node, drr_graph_node.Get()); + const auto& drr_op = GetDrrOp(drr_node); + if (!drr_op.has_value()) return adt::Ok{}; + ADT_LET_CONST_REF(pir_node, + match_ctx->GetSoleBigGraphNode(drr_graph_node)); + const auto& pir_op = GetPirOp(pir_node); + if (pir_op.has_value()) { + ADT_RETURN_IF_ERR(Yield(pir_op.value())); + } + return adt::Ok{}; + }; + return VisitEachGraphNode(source_pattern_ctx, DoEachOp); + } + + template + adt::Result VisitEachGraphNode( + const ap::drr::SourcePatternCtx& source_pattern_ctx, + const YieldT& Yield) { + std::list sources; + for (const auto& drr_node : source_pattern_ctx->node_arena->nodes()) { + const auto& drr_graph_node = drr_node.node(); + ADT_LET_CONST_REF(upstreams, drr_graph_node.UpstreamNodes()); + if (upstreams.size() == 0) { + sources.push_back(drr_graph_node); + } + } + using Ok = adt::Result; + ap::drr::DefaultDrrGraphDescriptor graph{}; + auto VisitPrev = [&](const DrrGraphNode& node, const auto& Yield) -> Ok { + return graph.VisitUpstreamNodes(node, Yield); + }; + auto VisitNext = [&](const DrrGraphNode& node, const auto& Yield) -> Ok { + return graph.VisitDownstreamNodes(node, Yield); + }; + ap::adt::TopoWalker walker{VisitPrev, VisitNext}; + ADT_RETURN_IF_ERR(walker(sources.begin(), sources.end(), Yield)); + return adt::Ok{}; + } + + std::optional GetDrrOp(const DrrNode& drr_node) const { + return drr_node.Match( + [&](const DrrNativeIrOp& ir_op) -> std::optional { + return DrrIrOp{ir_op}; + }, + [&](const DrrPackedIrOp& ir_op) -> std::optional { + return DrrIrOp{ir_op}; + }, + [&](const auto&) -> std::optional { return std::nullopt; }); + } + + std::optional GetPirOp(const PirNode& 
pir_node) const { + return pir_node.Match( + [&](const ap::paddle::NativeIrOp& ir_op) + -> std::optional { return ir_op.op; }, + [&](const ap::paddle::PackedIrOp& ir_op) + -> std::optional { return ir_op.fusion_op; }, + [&](const auto&) -> std::optional { + return std::nullopt; + }); + } + + template + adt::Result VisitMatchedInput( + const ap::drr::SourcePatternCtx& source_pattern_ctx, + const GraphMatchCtx& match_ctx, + const YieldT& Yield) { + for (const auto& drr_node : source_pattern_ctx->node_arena->nodes()) { + const auto& drr_graph_node = drr_node.node(); + ADT_LET_CONST_REF(upstreams, drr_graph_node.UpstreamNodes()); + if (upstreams.size() > 0) { + continue; + } + ADT_RETURN_IF_ERR(drr_node.Match( + [&](const DrrNativeIrValue& impl) -> adt::Result { + return VisitMatchedNativeIrValueInput(impl, match_ctx, Yield); + }, + [&](const DrrPackedIrValue& impl) -> adt::Result { + return VisitMatchedPackedIrValueInput(impl, match_ctx, Yield); + }, + [&](const DrrNativeIrOp&) -> adt::Result { + // Do nothing. + return adt::Ok{}; + }, + [&](const DrrPackedIrOp&) -> adt::Result { + // Do nothing. + return adt::Ok{}; + }, + [&](const DrrOptPackedIrOp&) -> adt::Result { + // Do nothing. 
+ return adt::Ok{}; + }, + [&](const auto&) -> adt::Result { + return adt::errors::NotImplementedError{ + "VisitMatchedInput() failed"}; + })); + } + return adt::Ok{}; + } + + template + adt::Result VisitMatchedNativeIrValueInput( + const DrrNativeIrValue& drr_value, + const GraphMatchCtx& match_ctx, + const YieldT& Yield) { + auto DoEach = [&](const PirNode& pir_node) -> adt::Result { + ADT_LET_CONST_REF(native_ir_value, + pir_node.template TryGet()); + return Yield(drr_value->name, native_ir_value.value); + }; + return match_ctx->VisitBigGraphIrValueNode(drr_value->node, DoEach); + } + + template + adt::Result VisitMatchedPackedIrValueInput( + const DrrPackedIrValue& drr_value, + const GraphMatchCtx& match_ctx, + const YieldT& Yield) { + int i = 0; + auto DoEach = [&](const PirNode& pir_node) -> adt::Result { + ADT_LET_CONST_REF(native_ir_value, + pir_node.template TryGet()); + const auto& name = drr_value->name + "[" + std::to_string(i++) + "]"; + return Yield(drr_value->name, native_ir_value.value); + }; + return match_ctx->VisitBigGraphIrValueNode(drr_value->node, DoEach); + } + + template + adt::Result VisitEachCreatedDrrCtx(const YieldT& Yield) { + using AList = ap::axpr::AbstractList; + if (AList::CastableFrom(drr_pass_obj_)) { + auto DoEach = + [&](const auto& drr_pass_obj) -> adt::Result { + ADT_LET_CONST_REF(drr_ctx, GetDrrCtx(drr_pass_obj)); + ADT_RETURN_IF_ERR(Yield(drr_ctx)); + return adt::Continue{}; + }; + ADT_LET_CONST_REF(lst, AList::CastFrom(drr_pass_obj_)); + ADT_RETURN_IF_ERR(lst.Visit(DoEach)); + } else { + ADT_LET_CONST_REF(drr_ctx, GetDrrCtx(drr_pass_obj_)); + ADT_RETURN_IF_ERR(Yield(drr_ctx)); + } + return adt::Ok{}; + } + + adt::Result GetDrrCtx(const ap::axpr::Value& drr_pass_obj) { + ApDrrHelper helper{circlable_ref_list_}; + ADT_LET_CONST_REF(drr_ctx, helper.CreateDrrCtxByDrrPassObj(drr_pass_obj)); + if (!drr_ctx->pass_name.has_value()) { + drr_ctx.shared_ptr()->pass_name = + std::string("tmp_access_drr_pass_") + 
std::to_string(seq_no_++); + } + return drr_ctx; + } +}; + +adt::Result TryGetRegistrySingleton() { + ap::paddle::ForceLinkPir(); + ap::paddle::ForceLinkIrTools(); + ADT_LET_CONST_REF(registry, ApRegistryHelper{}.SingletonRegistry()); + return registry; +} + +std::optional GetRegistrySingleton() { + const auto& registry = TryGetRegistrySingleton(); + if (registry.HasOkValue()) { + return registry.GetOkValue(); + } else { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << registry.GetError().CallStackToString() << "\n" + << registry.GetError().class_name() << ": " + << registry.GetError().msg(); + return std::nullopt; + } +} + +} // namespace + +std::optional> +CreateApLowerFusionOpAbstractDrrPass( + const std::weak_ptr& circlable_ref_list) { + if (!GetRegistrySingleton().has_value()) { + return std::nullopt; + } + auto drr_ctx_provider = + std::make_shared(circlable_ref_list); + const auto& drr_ctx_list = drr_ctx_provider->GetDrrCtxList(); + if (drr_ctx_list.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx_list.GetError().CallStackToString() << "\n" + << drr_ctx_list.GetError().class_name() << ": " + << drr_ctx_list.GetError().msg(); + return std::nullopt; + } + if (drr_ctx_list.GetOkValue()->empty()) { + return std::nullopt; + } + std::unique_ptr<::pir::Pass> pass = + std::make_unique(drr_ctx_provider, + /*name=*/"abstract", + /*steps_limit=*/std::nullopt); + return std::move(pass); +} + +std::optional> CreateApLowerFusionOpClassicDrrPass( + const std::weak_ptr& circlable_ref_list) { + if (!GetRegistrySingleton().has_value()) { + return std::nullopt; + } + auto drr_ctx_provider = + std::make_shared(circlable_ref_list); + const auto& drr_ctx_list = drr_ctx_provider->GetDrrCtxList(); + if (drr_ctx_list.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx_list.GetError().CallStackToString() << "\n" + << drr_ctx_list.GetError().class_name() << ": " + << drr_ctx_list.GetError().msg(); + return 
std::nullopt; + } + if (drr_ctx_list.GetOkValue()->empty()) { + return std::nullopt; + } + std::unique_ptr<::pir::Pass> pass = + std::make_unique(drr_ctx_provider, + /*name=*/"classic", + /*steps_limit=*/std::nullopt); + return std::move(pass); +} + +std::optional> CreateAccessTopoDrrPass( + const std::weak_ptr& circlable_ref_list, + const std::string& drr_pass_tag, + std::optional steps_limit) { + if (!GetRegistrySingleton().has_value()) { + return std::nullopt; + } + auto drr_ctx_provider = std::make_shared( + circlable_ref_list, drr_pass_tag); + const auto& drr_ctx_list = drr_ctx_provider->GetDrrCtxList(); + if (drr_ctx_list.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx_list.GetError().CallStackToString() << "\n" + << drr_ctx_list.GetError().class_name() << ": " + << drr_ctx_list.GetError().msg(); + return std::nullopt; + } + if (drr_ctx_list.GetOkValue()->empty()) { + return std::nullopt; + } + std::unique_ptr<::pir::Pass> pass = + std::make_unique(drr_ctx_provider, + /*name=*/"tag_access_topo", + /*steps_limit=*/steps_limit); + return std::move(pass); +} + +std::optional> CreateCustomAccessTopoDrrPass( + const std::weak_ptr& circlable_ref_list, + const ap::axpr::Value& drr_pass_obj, + std::optional steps_limit, + const ap::axpr::Value& mut_matched_pattern_as_programs) { + auto drr_ctx_provider = std::make_shared( + circlable_ref_list, drr_pass_obj, mut_matched_pattern_as_programs); + const auto& drr_ctx_list = drr_ctx_provider->GetDrrCtxList(); + if (drr_ctx_list.HasError()) { + LOG(ERROR) << "\nTraceback (most recent call last):\n" + << drr_ctx_list.GetError().CallStackToString() << "\n" + << drr_ctx_list.GetError().class_name() << ": " + << drr_ctx_list.GetError().msg(); + return std::nullopt; + } + if (drr_ctx_list.GetOkValue()->empty()) { + return std::nullopt; + } + std::unique_ptr<::pir::Pass> pass = + std::make_unique(drr_ctx_provider, + /*name=*/"custom_access_topo", + /*steps_limit=*/steps_limit); + return 
std::move(pass); +} + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/src/paddle/pass/ap_registry_helper.cc b/paddle/ap/src/paddle/pass/ap_registry_helper.cc new file mode 100644 index 00000000000000..fccbf3cdf5f336 --- /dev/null +++ b/paddle/ap/src/paddle/pass/ap_registry_helper.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pass/ap_registry_helper.h" +#include "paddle/ap/include/registry/registry_mgr.h" + +namespace cinn::dialect::ir { + +namespace { + +using ap::registry::Registry; +using ap::registry::RegistryMgr; +using ap::registry::RegistrySingleton; + +} // namespace + +ap::adt::Result ApRegistryHelper::SingletonRegistry() { + ADT_RETURN_IF_ERR(RegistryMgr::Singleton()->LoadAllOnce()); + ADT_LET_CONST_REF(registry, RegistrySingleton::Singleton()); + return registry; +} + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/src/paddle/pass/ir_helper_method_class.cc b/paddle/ap/src/paddle/pass/ir_helper_method_class.cc new file mode 100644 index 00000000000000..c23d5100e9d275 --- /dev/null +++ b/paddle/ap/src/paddle/pass/ir_helper_method_class.cc @@ -0,0 +1,390 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pass/ir_helper_method_class.h" +#include "paddle/ap/include/axpr/module_mgr.h" +#include "paddle/ap/include/axpr/to_string.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/paddle/pir_graph_descriptor.h" +#include "paddle/ap/include/paddle/pir_node_descriptor.h" + +namespace ap::paddle { + +struct PirHelperMethodClass { + using This = PirHelperMethodClass; + using GraphMatchCtx = ir_match::GraphMatchCtx; + + static adt::Result CreatePassManager( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 0); + auto* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + PassManager pass_manager{std::make_shared<::pir::PassManager>(ctx, 3)}; + return GetPirPassManagerClass().New(pass_manager); + } + + static adt::Result CreateAccessTopoDrrPass( + axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "create_ap_drr_pass() takes 1 arguments, but " + + std::to_string(args.size()) + " were given"}; + std::optional> opt_pass; + if (args.at(0).template CastableTo()) { + ADT_LET_CONST_REF(drr_pass_tag_name, + args.at(0).template CastTo()); + opt_pass = cinn::dialect::ir::CreateAccessTopoDrrPass( + interpreter->circlable_ref_list(), + drr_pass_tag_name, + /*steps_limit=*/std::nullopt); + } else { + opt_pass = cinn::dialect::ir::CreateCustomAccessTopoDrrPass( + 
interpreter->circlable_ref_list(), + args.at(0), + /*steps_limit=*/std::nullopt, + /*mut_matched_pattern_as_programs=*/adt::Nothing{}); + } + if (!opt_pass.has_value()) { + return adt::Nothing{}; + } + Pass pass{std::move(opt_pass.value())}; + return GetPirPassClass().New(pass); + } + + static adt::Result CreateAccessTopoDrrOneStepPass( + axpr::InterpreterBase* interpreter, + const axpr::Value&, + const std::vector& packed_args_val) { + const auto [args, kwargs] = *axpr::CastToPackedArgs(packed_args_val); + ADT_CHECK(args->size() == 1) << adt::errors::TypeError{ + std::string() + "create_ap_drr_pass() takes 1 arguments, but " + + std::to_string(args->size()) + " were given"}; + std::optional> opt_pass; + if (args->at(0).template CastableTo()) { + ADT_LET_CONST_REF(drr_pass_tag_name, + args->at(0).template CastTo()); + opt_pass = cinn::dialect::ir::CreateAccessTopoDrrPass( + interpreter->circlable_ref_list(), + drr_pass_tag_name, + /*steps_limit=*/1); + } else { + std::optional matched_pattern_mut_list{ + kwargs->OptGet("matched_pattern_mut_list")}; + if (!matched_pattern_mut_list.has_value()) { + matched_pattern_mut_list = adt::Nothing{}; + } + opt_pass = cinn::dialect::ir::CreateCustomAccessTopoDrrPass( + interpreter->circlable_ref_list(), + args->at(0), + /*steps_limit=*/1, + /*mut_matched_pattern_as_programs=*/matched_pattern_mut_list.value()); + } + if (!opt_pass.has_value()) { + return adt::Nothing{}; + } + Pass pass{std::move(opt_pass.value())}; + return GetPirPassClass().New(pass); + } + + static adt::Result CreateDeadCodeEliminationPass( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 0) << adt::errors::TypeError{ + std::string() + "create_dce_pass() takes 0 arguments, but " + + std::to_string(args.size()) + " were given"}; + Pass pass{pir::CreateDeadCodeEliminationPass()}; + return GetPirPassClass().New(pass); + } + + static adt::Result CopyFusedOpsToProgram( + const axpr::Value&, const std::vector& packed_args_val) { + const 
auto [args, kwargs] = *axpr::CastToPackedArgs(packed_args_val); + ADT_CHECK(args->size() == 1) << adt::errors::TypeError{ + std::string() + "copy_fused_ops_to_program() takes 1 arguments, but " + + std::to_string(args->size()) + " were given"}; + ADT_LET_CONST_REF(pir_node, PirNodeHelper{}.CastFromAxprValue(args->at(0))) + << adt::errors::TypeError{ + std::string() + + "the first argument of copy_fused_ops_to_program() must be a " + "PackedIrOp/OptPackedIrOp (not " + + axpr::GetTypeName(args->at(0)) + ")"}; + ADT_LET_CONST_REF(tensor_match_ctx_val, kwargs->Get("tensor_match_ctx")) + << adt::errors::TypeError{ + std::string() + + "copy_fused_ops_to_program() need keyword argument " + "'tensor_match_ctx' of 'TensorMatchCtx' type "}; + ADT_LET_CONST_REF(tensor_match_ctx, + tensor_match_ctx_val + .template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "copy_fused_ops_to_program() need keyword argument " + "'tensor_match_ctx' of 'TensorMatchCtx' type "}; + std::unordered_map map; + ADT_RETURN_IF_ERR( + This{}.InitPirValue2Name(&map, tensor_match_ctx, pir_node)); + auto NameGetter = [&](pir::Value value) -> adt::Result { + const auto& iter = map.find(value); + ADT_CHECK(iter != map.end()); + return &iter->second; + }; + using RetT = adt::Result; + return pir_node.Match( + [&](const PackedIrOp& packed_ir_op) -> RetT { + return This{}.CopyPackedIrOpBlockToProgram(packed_ir_op, NameGetter); + }, + [&](const auto&) -> RetT { + return adt::errors::TypeError{ + std::string() + + "the first argument of copy_fused_ops_to_program() must be a " + "PackedIrOp (not " + + axpr::GetTypeName(args->at(0)) + ")"}; + }); + } + + adt::Result InitPirValue2Name( + std::unordered_map* map, + const ir_match::TensorMatchCtx& tensor_match_ctx, + const PirNode& pir_node) { + ADT_LET_CONST_REF(ir_match_ctx, + adt::WeakPtrLock(tensor_match_ctx->ir_mtach_ctx)); + const auto& graph_match_ctx = ir_match_ctx->graph_match_ctx; + ADT_LET_CONST_REF(drr_graph_node, + 
graph_match_ctx->GetMatchedSmallGraphNode(pir_node)); + ADT_LET_CONST_REF(drr_node, drr_graph_node.Get()); + ADT_RETURN_IF_ERR(CheckIsOpNode(drr_node)); + using Ok = adt::Result; + auto DoEachNameAndIrValue = [&](const std::string& name, + pir::Value val) -> Ok { + if (!map->emplace(val, name).second) { + ADT_CHECK(map->at(val) == name) << adt::errors::ValueError{ + std::string() + "InitPirValue2Name() failed. old_name: " + + map->at(val) + ", new_name: " + name}; + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitInputNameAndPirValue( + graph_match_ctx, drr_node, DoEachNameAndIrValue)); + return adt::Ok{}; + } + + template + adt::Result VisitInputNameAndPirValue( + const GraphMatchCtx& graph_match_ctx, + const drr::Node& drr_node, + const YieldT& Yield) { + using Ok = adt::Result; + auto DoEach = [&](const auto& drr_graph_node) -> Ok { + ADT_LET_CONST_REF(upstreams_of_upstream, drr_graph_node.UpstreamNodes()); + ADT_LET_CONST_REF(input_graph_node, upstreams_of_upstream.Sole()); + ADT_LET_CONST_REF(input_drr_node, input_graph_node.Get()); + ADT_RETURN_IF_ERR( + VisitNameAndPirValue(graph_match_ctx, input_drr_node, Yield)); + return adt::Ok{}; + }; + ADT_LET_CONST_REF(upstreams, drr_node.node().UpstreamNodes()); + ADT_RETURN_IF_ERR(upstreams.VisitNodes(DoEach)); + return adt::Ok{}; + } + + template + adt::Result VisitNameAndPirValue( + const GraphMatchCtx& graph_match_ctx, + const drr::Node& drr_node, + const YieldT& Yield) { + using Ok = adt::Result; + return drr_node.Match( + [&](const drr::NativeIrValue& impl) -> Ok { + const auto& node = impl->node; + std::size_t i = 0; + auto DoEach = [&](const PirNode& pir_node) -> Ok { + ADT_LET_CONST_REF( + native_ir_value, + pir_node.template TryGet()); + ADT_CHECK(i++ == 0); + ADT_RETURN_IF_ERR(Yield(impl->name, native_ir_value.value)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + graph_match_ctx->VisitBigGraphIrValueNode(node, DoEach)); + return adt::Ok{}; + }, + [&](const drr::PackedIrValue& impl) -> Ok { + const 
auto& node = impl->node; + std::size_t i = 0; + auto DoEach = [&](const PirNode& pir_node) -> Ok { + ADT_LET_CONST_REF( + native_ir_value, + pir_node.template TryGet()); + const auto& name = impl->name + "[" + std::to_string(i++) + "]"; + ADT_RETURN_IF_ERR(Yield(name, native_ir_value.value)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + graph_match_ctx->VisitBigGraphIrValueNode(node, DoEach)); + return adt::Ok{}; + }, + [&](const auto&) -> Ok { + return adt::errors::TypeError{ + "copy_fused_ops_to_program() failed. the inputs of DrrPackedIrOp " + "should be a DrrNativeIrValue or DrrPackedIrValue"}; + }); + } + + adt::Result CheckIsOpNode(const drr::Node& drr_node) { + using Ok = adt::Result; + return drr_node.Match( + [&](const drr::PackedIrOp&) -> Ok { return adt::Ok{}; }, + [&](const drr::OptPackedIrOp&) -> Ok { return adt::Ok{}; }, + [&](const auto& impl) -> Ok { + return adt::errors::TypeError{ + std::string() + + "the argument 1 of ir_helper.copy_fused_ops_to_program() should " + "be a PackedIrOp/RefIrOp"}; + }); + } + + template + adt::Result CopyPackedIrOpBlockToProgram( + const PackedIrOp& packed_ir_op, const NameGetterT& NameGetter) { + auto* block = packed_ir_op.fusion_op.block(); + pir::IrContext* ctx = ::pir::IrContext::Instance(); + auto new_program = std::make_shared<::pir::Program>(ctx); + auto clone_options = ::pir::CloneOptions::All(); + pir::IrMapping ir_mapping; + ADT_RETURN_IF_ERR(InitIrMapping(NameGetter, + pir::GetUsedExternalValue(*block), + &ir_mapping, + new_program->block())); + for (const auto& op : *block) { + auto* new_op = op.Clone(ir_mapping, clone_options); + new_program->block()->push_back(new_op); + } + ADT_RETURN_IF_ERR( + CloneSymbolicShapes(packed_ir_op.fusion_op->GetParentProgram(), + new_program.get(), + ir_mapping)); + Program ap_program{new_program}; + return GetPirProgramClass().New(ap_program); + } + + adt::Result CloneSymbolicShapes(pir::Program* new_program, + pir::Program* old_program, + const pir::IrMapping& 
ir_mapping) { + auto* new_shape_analysis = + &::pir::ShapeAnalysisManager::Instance().Get(new_program); + auto* old_shape_analysis = + &::pir::ShapeAnalysisManager::Instance().Get(old_program); + for (const auto& [old_value, new_value] : ir_mapping.GetMap()) { + new_shape_analysis->SetShapeOrDataForValue( + new_value, old_shape_analysis->GetShapeOrDataForValue(old_value)); + } + return adt::Ok{}; + } + + template + adt::Result InitIrMapping(const NameGetterT& NameGetter, + const std::vector& free_values, + pir::IrMapping* ir_mapping, + pir::Block* block) { + int i = 0; + pir::Builder builder(pir::IrContext::Instance(), block); + for (const auto& free_value : free_values) { + ADT_LET_CONST_REF(name, NameGetter(free_value)); + ADT_CHECK(free_value.type().isa()); + const auto& type = free_value.type().dyn_cast(); + const auto& dims = ::common::vectorize(type.dims()); + auto phi_type = ::paddle::dialect::TransToPhiDataType(type.dtype()); + auto op = builder.Build<::paddle::dialect::DataOp>( + *name, dims, phi_type, phi::Place()); + ir_mapping->Add(free_value, op->result(0)); + } + return adt::Ok{}; + } + + static adt::Result Match( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& args) { + ADT_CHECK(args.size() == 2) << adt::errors::TypeError{ + std::string() + "PirHelper.match() takes 2 arguments, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(program, args.at(0).template CastTo()) + << adt::errors::TypeError{std::string() + + "the argument 1 of PirHelper.match() should " + "b a PirProgram (not " + + axpr::GetTypeName(args.at(0)) + ")"}; + ADT_CHECK(axpr::CallableHelper{}.IsCallable(args.at(1))) + << adt::errors::TypeError{std::string() + + "the argument 2 of PirHelper.match() should " + "be callable object (not " + + axpr::GetTypeName(args.at(1)) + ")"}; + std::vector src_ptn_func_args{std::string("fake_pass"), + args.at(1)}; + ADT_LET_CONST_REF(lambda, This{}.GetDrrCtxMaker()); + axpr::Function 
function{lambda, std::nullopt}; + ADT_LET_CONST_REF( + drr_ctx, + cinn::dialect::ir::ApDrrHelper{interpreter->circlable_ref_list()} + .InterpretDrrCtxMaker(function, src_ptn_func_args)); + ADT_CHECK(drr_ctx->source_pattern_ctx.has_value()); + ap::paddle::PackedIrOpInnerSourcePatternHelper src_pattern_helper{}; + ADT_LET_CONST_REF( + opt_graph_match_ctx, + src_pattern_helper.Match(program->pir_program->block(), + drr_ctx->source_pattern_ctx.value())); + return opt_graph_match_ctx.has_value(); + } + + adt::Result> GetDrrCtxMaker() { + using LambdaT = adt::Result>; + static LambdaT lambda([]() -> LambdaT { + auto GetBody = [&](auto& ctx) -> axpr::AnfExpr { + auto& drr_ctx = ctx.Var("DrrCtx").Call(); + drr_ctx.Attr("init_pass_name").Call(ctx.Var("pass_name")); + drr_ctx.Attr("init_source_pattern").Call(ctx.Var("src_ptn_func")); + return drr_ctx; + }; + axpr::LambdaExprBuilder lmbd; + const auto& anf_expr = + lmbd.Lambda({"pass_name", "src_ptn_func"}, GetBody); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF( + atomic, core_expr.template TryGet>()); + return atomic.template TryGet>(); + }()); + return lambda; + } +}; + +void ForceLinkIrTools() { + // Do nothing. 
+} + +REGISTER_AP_BUILTIN_MODULE("ir_tools", [](auto* m) { + LOG(ERROR) << "REGISTER_AP_BUILTIN_MODULE ir_tools"; + using Impl = PirHelperMethodClass; + m->Def("create_pass_manager", &Impl::CreatePassManager); + m->Def("create_access_topo_drr_pass", &Impl::CreateAccessTopoDrrPass); + m->Def("create_access_topo_drr_one_step_pass", + &Impl::CreateAccessTopoDrrOneStepPass); + m->Def("create_dce_pass", &Impl::CreateDeadCodeEliminationPass); + m->Def("copy_fused_ops_to_program", &Impl::CopyFusedOpsToProgram); + m->Def("match", &Impl::Match); +}); + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/op_factory.cc b/paddle/ap/src/paddle/pass/op_factory.cc new file mode 100644 index 00000000000000..e39cd6d9c3d4e5 --- /dev/null +++ b/paddle/ap/src/paddle/pass/op_factory.cc @@ -0,0 +1,216 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/src/paddle/pass/op_factory.h" +#include "paddle/ap/include/paddle/pir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" + +namespace ap::paddle { + +namespace { + +adt::Result ConstructPdOpSum( + pir::Builder* builder, + const std::vector& inputs, + pir::AttributeMap attrs) { + ADT_CHECK(inputs.size() == 1); + attrs["dtype"] = ::paddle::dialect::DataTypeAttribute::get( + pir::IrContext::Instance(), phi::DataType::UNDEFINED); + auto op = builder->Build<::paddle::dialect::SumOp>(inputs.at(0), attrs); + return op; +} + +adt::Result ConstructUpSpiderOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 2) << adt::errors::TypeError{ + std::string() + "'ap_op.up_spider' op takes 2 arguments, but " + + std::to_string(inputs.size()) + " were given"}; + auto op = builder->Build(inputs.at(0), inputs.at(1)); + return op; +} + +adt::Result ConstructYieldOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + auto op = builder->Build(inputs); + return op; +} + +adt::Result ConstructShadowOutputOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 1); + const auto& iter = attrs.find("output_name"); + ADT_CHECK(iter != attrs.end()); + ADT_CHECK(iter->second.isa()); + const std::string& output_name = + iter->second.dyn_cast().AsString(); + auto op = builder->Build(inputs.at(0), output_name); + return op; +} + +adt::Result ConstructDownSpiderOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 1); + auto 
op = builder->Build(inputs.at(0)); + return op; +} + +adt::Result ConstructLoadFromGlobalOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 1); + const auto& iter = attrs.find("index_func_unique_id"); + ADT_CHECK(iter != attrs.end()); + ADT_CHECK(iter->second.isa()); + const std::string& unique_id = + iter->second.dyn_cast().AsString(); + auto op = + builder->Build(inputs.at(0), unique_id); + return op; +} + +adt::Result ConstructStoreToGlobalOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 2); + const auto& iter = attrs.find("index_func_unique_id"); + ADT_CHECK(iter != attrs.end()); + ADT_CHECK(iter->second.isa()); + const std::string& unique_id = + iter->second.dyn_cast().AsString(); + auto op = builder->Build( + inputs.at(0), inputs.at(1), unique_id); + return op; +} + +adt::Result ConstructLoadFromRegisterOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 0); + // type + const auto& type_iter = attrs.find("type"); + ADT_CHECK(type_iter != attrs.end()); + ADT_CHECK(type_iter->second.isa()); + const auto& type = type_iter->second.dyn_cast().data(); + // symbolic_shape_or_data + const auto& symbolic_shape_or_data_iter = + attrs.find("symbolic_shape_or_data"); + ADT_CHECK(symbolic_shape_or_data_iter != attrs.end()); + ADT_CHECK( + symbolic_shape_or_data_iter->second.isa()); + const auto& symbolic_shape_or_data = + symbolic_shape_or_data_iter->second + .dyn_cast() + .data(); + // name + const auto& name_iter = attrs.find("name"); + ADT_CHECK(name_iter != attrs.end()); + ADT_CHECK(name_iter->second.isa()); + const std::string& name = + name_iter->second.dyn_cast().AsString(); + // register_var_name + const auto& register_var_name_iter = attrs.find("register_var_name"); + ADT_CHECK(register_var_name_iter != attrs.end()); + 
ADT_CHECK(register_var_name_iter->second.isa()); + const std::string& register_var_name = + register_var_name_iter->second.dyn_cast().AsString(); + auto op = builder->Build( + type, symbolic_shape_or_data, name, register_var_name); + return op; +} + +adt::Result ConstructStoreToRegisterOp( + pir::Builder* builder, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + ADT_CHECK(inputs.size() == 1); + // name + const auto& name_iter = attrs.find("name"); + ADT_CHECK(name_iter != attrs.end()); + ADT_CHECK(name_iter->second.isa()); + const std::string& name = + name_iter->second.dyn_cast().AsString(); + // register_var_name + const auto& register_var_name_iter = attrs.find("register_var_name"); + ADT_CHECK(register_var_name_iter != attrs.end()); + ADT_CHECK(register_var_name_iter->second.isa()); + const std::string& register_var_name = + register_var_name_iter->second.dyn_cast().AsString(); + auto op = builder->Build( + inputs.at(0), name, register_var_name); + return op; +} + +} // namespace + +adt::Result> CreateOperation( + pir::Builder* builder, + const std::string& op_name, + const std::vector& inputs, + const pir::AttributeMap& attrs) { + if (op_name == "pd_op.sum") { + ADT_LET_CONST_REF(ret, ConstructPdOpSum(builder, inputs, attrs)); + return ret; + } + if (op_name == "cf.yield") { + ADT_LET_CONST_REF(ret, ConstructYieldOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "builtin.shadow_output") { + ADT_LET_CONST_REF(ret, ConstructShadowOutputOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "ap_op.up_spider") { + ADT_LET_CONST_REF(ret, ConstructUpSpiderOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "ap_op.down_spider") { + ADT_LET_CONST_REF(ret, ConstructDownSpiderOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "ap_op.load_from_global") { + ADT_LET_CONST_REF(ret, ConstructLoadFromGlobalOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "ap_op.store_to_global") { + 
ADT_LET_CONST_REF(ret, ConstructStoreToGlobalOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "ap_op.load_from_register") { + ADT_LET_CONST_REF(ret, ConstructLoadFromRegisterOp(builder, inputs, attrs)); + return ret; + } + if (op_name == "ap_op.store_to_register") { + ADT_LET_CONST_REF(ret, ConstructStoreToRegisterOp(builder, inputs, attrs)); + return ret; + } + return std::nullopt; +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pass/op_factory.h b/paddle/ap/src/paddle/pass/op_factory.h new file mode 100644 index 00000000000000..65bf6bbea1b622 --- /dev/null +++ b/paddle/ap/src/paddle/pass/op_factory.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/pir/include/core/attribute.h" +#include "paddle/pir/include/core/builder.h" +#include "paddle/pir/include/core/type.h" + +namespace pir { + +class Operation; + +} + +namespace ap::paddle { + +// Returns nullopt if op_name not supported. 
+adt::Result> CreateOperation( + pir::Builder* builder, + const std::string& op_name, + const std::vector& inputs, + const pir::AttributeMap& attributes); + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc b/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc new file mode 100644 index 00000000000000..95a3dfba82e23a --- /dev/null +++ b/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/paddle/phi/ap_infer_meta_helper.h" +#include +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/data_type.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/axpr/value_method_class.h" +#include "paddle/ap/include/paddle/builtin_frame_util.h" +#include "paddle/ap/include/paddle/const_meta_tensor_ptr.h" +#include "paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h" +#include "paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h" +#include "paddle/ap/include/paddle/ddim.h" +#include "paddle/ap/include/paddle/ddim_method_class.h" +#include "paddle/ap/include/paddle/meta_tensor_ptr.h" +#include "paddle/ap/include/paddle/meta_tensor_ptr_method_class.h" +#include "paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h" + +namespace phi { + +namespace { + +using CoreExpr = ap::axpr::CoreExpr; +using Lambda = ap::axpr::Lambda; + +adt::Result InferMetaByLambda( + const Lambda& lambda, + const std::vector* inputs, + std::vector* outputs) { + ap::memory::Guard guard{}; + ap::axpr::Interpreter interpreter( + ap::paddle::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + ADT_RETURN_IF_ERR(interpreter.Interpret( + lambda, + {ap::paddle::GetConstStdVectorConstMetaTensorPtrPtrClass().New(inputs), + ap::paddle::GetStdVectorMetaTensorPtrPtrClass().New(outputs)})); + return adt::Ok{}; +} + +adt::Result MakeLambda(const std::string& lambda_str) { + ADT_LET_CONST_REF(anf_expr, ap::axpr::MakeAnfExprFromJsonString(lambda_str)); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + ADT_LET_CONST_REF(atomic, + core_expr.TryGet>()) + << adt::errors::TypeError{ + std::string() + + "lambda_str can not be converted to atomic AnfExpr."}; + ADT_LET_CONST_REF(lambda, + atomic.TryGet>()); + return lambda; +} + +using MakeLambdaT = 
adt::Result (*)(const std::string& lambda_str); + +template +adt::Result CacheConvertResult(const std::string& lambda_str) { + static std::unordered_map> cache; + static std::mutex mutex; + std::unique_lock lock(mutex); + auto iter = cache.find(lambda_str); + if (iter == cache.end()) { + iter = cache.emplace(lambda_str, Make(lambda_str)).first; + } + ADT_LET_CONST_REF(lambda, iter->second); + return lambda; +} + +constexpr MakeLambdaT CastToLambda = &CacheConvertResult<&MakeLambda>; + +} // namespace + +adt::Result ApInferMetaHelper::InferMeta( + const std::string& lambda_str, + const std::vector* inputs, + std::vector* outputs) { + ADT_LET_CONST_REF(lambda, CastToLambda(lambda_str)); + return InferMetaByLambda(lambda, inputs, outputs); +} + +} // namespace phi diff --git a/paddle/ap/src/paddle/phi/ap_unary_kernel.cc b/paddle/ap/src/paddle/phi/ap_unary_kernel.cc new file mode 100644 index 00000000000000..63a9bef64db131 --- /dev/null +++ b/paddle/ap/src/paddle/phi/ap_unary_kernel.cc @@ -0,0 +1,304 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/kernel_dispatch/ap_unary_kernel.h" + +#include +#include +#include +#include "glog/logging.h" +#include "paddle/common/enforce.h" + +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/kernel_dispatch/builtin_frame_util.h" +#include "paddle/ap/include/paddle/phi/kernel_define_helper.h" +#include "paddle/ap/include/paddle/phi/kernel_dispatch_helper.h" +#include "paddle/ap/include/rt_module/naive_module_maker.h" + +namespace ap { + +using MakeCoreExprT = adt::Result> (*)( + const std::string& json_str); + +adt::Result> ConvertToCoreExpr( + const std::string& json_str) { + const auto& anf_expr = ap::axpr::MakeAnfExprFromJsonString(json_str); + ADT_RETURN_IF_ERR(anf_expr); + const auto& core_expr = + ap::axpr::ConvertAnfExprToCoreExpr(anf_expr.GetOkValue()); + if (!core_expr.Has>()) { + return adt::errors::TypeError{ + std::string() + "json_str can not be converted to atomic AnfExpr."}; + } + const auto& atomic = core_expr.Get>(); + if (!atomic.Has>()) { + return adt::errors::TypeError{ + std::string() + "json_str can not be converted to lambda AnfExpr."}; + } + return atomic.Get>(); +} + +template +adt::Result> CacheCoreExpr( + const std::string& json_str) { + static std::unordered_map>> + json_str2cache; + static std::mutex mutex; + std::unique_lock lock(mutex); + auto iter = json_str2cache.find(json_str); + if (iter == json_str2cache.end()) { + const auto& core_expr = MakeCoreExpr(json_str); + iter = json_str2cache.emplace(json_str, core_expr).first; + } + return iter->second; +} + +constexpr MakeCoreExprT MakeOrGetCoreExpr = &CacheCoreExpr<&ConvertToCoreExpr>; + +namespace kernel_dispatch { + +using FuncName2ArgTypes = + std::unordered_map>; +FuncName2ArgTypes MakeFuncName2ArgTypes(const code_module::CodeModule& m) { + auto GetArgTypes = [&](const auto& declare) { return declare->arg_types; }; + FuncName2ArgTypes ret; + for (const auto& declare : *m->func_declares) { + ret[declare->func_id] = 
GetArgTypes(declare); + } + return ret; +} + +using MakeRtModuleT = adt::Result (*)( + const std::string& code_module_lambda); + +template +adt::Result CacheRtModule( + const std::string& code_module_lambda) { + using Definer2RtModule = + std::unordered_map>; + static Definer2RtModule definer2rt_module; + static std::mutex mutex; + std::unique_lock lock(mutex); + auto iter = definer2rt_module.find(code_module_lambda); + if (iter == definer2rt_module.end()) { + const auto& rt_module = MakeRtModule(code_module_lambda); + iter = definer2rt_module.emplace(code_module_lambda, rt_module).first; + } + return iter->second; +} + +adt::Result MakeRtModule( + const std::string& code_module_lambda) { + ADT_LET_CONST_REF(code_module_core_expr, + MakeOrGetCoreExpr(code_module_lambda)); + phi::KernelDefineHelper helper{}; + ADT_LET_CONST_REF(code_module, + helper.InterpretKernelDefineLambda(code_module_core_expr)); + using RetT = adt::Result; + return code_module->source_code.Match( + [&](const ap::code_module::Project&) -> RetT { + const char* ap_workspace_dir = std::getenv("AP_WORKSPACE_DIR"); + ADT_CHECK(ap_workspace_dir != nullptr) << adt::errors::TypeError{ + std::string() + "AP_WORKSPACE_DIR not set"}; + auto hash_value_str = + std::to_string(std::hash()(code_module_lambda)); + std::string workspace_dir = + std::string(ap_workspace_dir) + "/" + hash_value_str; + ap::rt_module::NaiveModuleMaker maker(workspace_dir); + auto Serialize = [&](const auto&) -> const std::string& { + return code_module_lambda; + }; + ADT_LET_CONST_REF(rt_module, maker.Make(code_module, Serialize)); + return rt_module; + }, + [&](const ap::code_module::Package&) -> RetT { + const char* ap_workspace_dir = std::getenv("AP_PACKAGE_DIR"); + ap::rt_module::NaiveModuleMaker maker(ap_workspace_dir); + auto Serialize = [&](const auto&) -> const std::string& { + return code_module_lambda; + }; + ADT_LET_CONST_REF(rt_module, maker.Make(code_module, Serialize)); + return rt_module; + }); +} + +constexpr 
MakeRtModuleT MakeOrGetRtModule = &CacheRtModule<&MakeRtModule>; + +adt::List MakeTensorDims(const phi::DenseTensor& tensor) { + adt::List ret; + ret->reserve(tensor.dims().size()); + for (int i = 0; i < tensor.dims().size(); ++i) { + ret->emplace_back(Val{tensor.dims().at(i)}); + } + return ret; +} + +adt::Result> GetIndexesSlices( + const ap::axpr::AttrMap& + kernel_dispatch_const_data, + const std::string& attr_name) { + ADT_LET_CONST_REF( + val, + kernel_dispatch_const_data + ->TryGet>(attr_name)); + return val; +} + +template +adt::Result VisitTensorIdxOrRange( + const adt::List& list, + const DoEachIdxT& DoEachIdx, + const DoEachRangeT& DoEachRange) { + using Ok = adt::Result; + for (int i = 0; i < list->size(); ++i) { + const auto& elt = list->at(i); + ADT_RETURN_IF_ERR(elt.Match( + [&](int64_t idx) -> Ok { + ADT_RETURN_IF_ERR(DoEachIdx(idx)); + return adt::Ok{}; + }, + [&](const adt::List& range_val) -> Ok { + ADT_CHECK(range_val->size() == 2); + ADT_LET_CONST_REF(start, range_val->at(0).TryGet()); + ADT_LET_CONST_REF(end, range_val->at(1).TryGet()); + ADT_RETURN_IF_ERR(DoEachRange(start, end)); + return adt::Ok{}; + }, + [&](const auto&) -> Ok { + return adt::errors::TypeError{"only index or index pair supported."}; + })); + } + return adt::Ok{}; +} + +adt::Result> MakeConstTensors( + const std::vector& xs, + const ap::axpr::AttrMap& + kernel_dispatch_const_data) { + ADT_LET_CONST_REF( + indexes_slices, + GetIndexesSlices(kernel_dispatch_const_data, + "__builtin_ap_kernel_input_indexes_slices")); + adt::List ret; + ret->reserve(xs.size()); + using Ok = adt::Result; + auto CollectTensor = [&](adt::List* list, + const phi::DenseTensor* x) -> Ok { + ConstTensorData tensor_data{x}; + adt::List dims{MakeTensorDims(*x)}; + ConstTensor const_tensor{tensor_data, dims}; + axpr::BuiltinClassInstance instance{GetConstTensorClass(), + const_tensor}; + (*list)->emplace_back(instance); + return adt::Ok{}; + }; + auto DoEachIdx = [&](std::size_t i) -> Ok { + ADT_CHECK(i 
< xs.size()); + const auto* x = xs.at(i); + ADT_RETURN_IF_ERR(CollectTensor(&ret, x)); + return adt::Ok{}; + }; + auto DoEachRange = [&](std::size_t start, std::size_t end) -> Ok { + ADT_CHECK(start <= end); + adt::List tensor_list; + tensor_list->reserve(end - start); + for (int i = start; i < end; ++i) { + ADT_CHECK(i < xs.size()); + const auto* x = xs.at(i); + ADT_RETURN_IF_ERR(CollectTensor(&tensor_list, x)); + } + ret->emplace_back(tensor_list); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitTensorIdxOrRange(indexes_slices, DoEachIdx, DoEachRange)); + return ret; +} + +adt::Result> MakeMutableTensors( + const std::vector& xs, + const ap::axpr::AttrMap& + kernel_dispatch_const_data) { + ADT_LET_CONST_REF( + indexes_slices, + GetIndexesSlices(kernel_dispatch_const_data, + "__builtin_ap_kernel_output_indexes_slices")); + adt::List ret; + ret->reserve(xs.size()); + + using Ok = adt::Result; + auto CollectTensor = [&](adt::List* list, phi::DenseTensor* x) -> Ok { + MutableTensorData tensor_data{x}; + adt::List dims{MakeTensorDims(*x)}; + MutableTensor mutable_tensor{tensor_data, dims}; + axpr::BuiltinClassInstance instance{GetMutableTensorClass(), + mutable_tensor}; + (*list)->emplace_back(instance); + return adt::Ok{}; + }; + auto DoEachIdx = [&](std::size_t i) -> Ok { + ADT_CHECK(i < xs.size()); + auto* x = xs.at(i); + ADT_RETURN_IF_ERR(CollectTensor(&ret, x)); + return adt::Ok{}; + }; + auto DoEachRange = [&](std::size_t start, std::size_t end) -> Ok { + ADT_CHECK(start <= end); + adt::List tensor_list; + tensor_list->reserve(end - start); + for (int i = start; i < end; ++i) { + ADT_CHECK(i < xs.size()); + auto* x = xs.at(i); + ADT_RETURN_IF_ERR(CollectTensor(&tensor_list, x)); + } + ret->emplace_back(tensor_list); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitTensorIdxOrRange(indexes_slices, DoEachIdx, DoEachRange)); + return ret; +} + +adt::Result ApUnaryKernel( + const DeviceCtx& device_ctx, + const std::vector& xs, + int num_outputs, + const 
std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs) { + phi::KernelDispatchHelper helper{}; + ADT_LET_CONST_REF(ctx_maker_lambda, + MakeOrGetCoreExpr(kernel_dispatch_const_data_lambda)); + ADT_LET_CONST_REF(ctx_maker_ret, helper.InterpretCtxMaker(ctx_maker_lambda)); + ADT_LET_CONST_REF( + kernel_dispatch_const_data, + ctx_maker_ret.TryGet>()); + ADT_LET_CONST_REF(rt_module, + kernel_dispatch::MakeOrGetRtModule(code_module_lambda)); + ADT_LET_CONST_REF(inputs, MakeConstTensors(xs, kernel_dispatch_const_data)); + ADT_LET_CONST_REF(outputs, + MakeMutableTensors(outs, kernel_dispatch_const_data)); + DispatchRawCtx raw_ctx{device_ctx, inputs, outputs, rt_module}; + DispatchCtx dispatch_ctx{raw_ctx, kernel_dispatch_const_data}; + ADT_LET_CONST_REF(lambda, MakeOrGetCoreExpr(kernel_dispatch_lambda)); + ADT_RETURN_IF_ERR(helper.InterpretKernelDispatcher(lambda, dispatch_ctx)); + return adt::Ok{}; +} + +} // namespace kernel_dispatch + +} // namespace ap diff --git a/paddle/ap/src/paddle/phi/kernel_define_helper.cc b/paddle/ap/src/paddle/phi/kernel_define_helper.cc new file mode 100644 index 00000000000000..805396bca548a3 --- /dev/null +++ b/paddle/ap/src/paddle/phi/kernel_define_helper.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/paddle/phi/kernel_define_helper.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/code_module/builtin_frame_util.h" +#include "paddle/ap/include/code_module/value.h" +#include "paddle/ap/include/code_module/value_method_class.h" +#include "paddle/ap/include/memory/guard.h" + +namespace phi { + +namespace { + +using CoreExpr = ap::axpr::CoreExpr; + +using Lambda = ap::axpr::Lambda; + +using CodeModule = ap::code_module::CodeModule; + +using Val = ap::code_module::Value; + +} // namespace + +adt::Result KernelDefineHelper::InterpretKernelDefineLambda( + const Lambda& lambda) { + ap::memory::Guard guard{}; + ap::axpr::Interpreter cps_interpreter( + ap::code_module::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + ADT_LET_CONST_REF(interpret_ret, cps_interpreter.Interpret(lambda, {})); + ADT_LET_CONST_REF(m, ap::axpr::Get(interpret_ret)); + return m; +} + +} // namespace phi diff --git a/paddle/ap/src/paddle/phi/kernel_dispatch_helper.cc b/paddle/ap/src/paddle/phi/kernel_dispatch_helper.cc new file mode 100644 index 00000000000000..9fd4e9578928ac --- /dev/null +++ b/paddle/ap/src/paddle/phi/kernel_dispatch_helper.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/paddle/phi/kernel_dispatch_helper.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/kernel_dispatch/builtin_frame_util.h" +#include "paddle/ap/include/kernel_dispatch/dispatch_ctx_method_class.h" +#include "paddle/ap/include/kernel_dispatch/value.h" +#include "paddle/ap/include/memory/guard.h" + +namespace phi { + +namespace { + +using CoreExpr = ap::axpr::CoreExpr; +using Lambda = ap::axpr::Lambda; +using Val = ap::kernel_dispatch::Val; +using DispatchCtx = ap::kernel_dispatch::DispatchCtx; + +} // namespace + +KernelDispatchHelper::KernelDispatchHelper() + : circlable_ref_list_(ap::memory::Guard{}.circlable_ref_list()) {} + +adt::Result KernelDispatchHelper::InterpretCtxMaker( + const Lambda& ctx_maker_lambda) { + ap::axpr::Interpreter cps_interpreter( + ap::kernel_dispatch::MakeBuiltinFrameAttrMap(), circlable_ref_list_); + ADT_LET_CONST_REF(ctx, cps_interpreter.Interpret(ctx_maker_lambda, {})); + return ctx; +} + +adt::Result KernelDispatchHelper::InterpretKernelDispatcher( + const Lambda& kernel_dispatch_lambda, const DispatchCtx& dispatch_ctx) { + const auto& cls = ap::kernel_dispatch::GetDispatchCtxClass(); + ap::axpr::BuiltinClassInstance instance{cls, dispatch_ctx}; + ap::axpr::Interpreter cps_interpreter( + ap::kernel_dispatch::MakeBuiltinFrameAttrMap(), circlable_ref_list_); + ADT_RETURN_IF_ERR( + cps_interpreter.Interpret(kernel_dispatch_lambda, {instance})); + return adt::Ok{}; +} + +} // namespace phi diff --git a/paddle/ap/src/paddle/pir/attribute_method_class.cc b/paddle/ap/src/paddle/pir/attribute_method_class.cc new file mode 100644 index 00000000000000..4478a6964203cf --- /dev/null +++ b/paddle/ap/src/paddle/pir/attribute_method_class.cc @@ -0,0 +1,584 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + + +#include "paddle/ap/include/paddle/pir/attribute_method_class.h" +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/paddle/phi/place_method_class.h" +#include "paddle/ap/include/paddle/pir/shape_or_data_method_class.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" + +namespace ap::paddle { + +inline adt::Result PirAttributeToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + ss << self; + return ss.str(); +} + +struct PirAttributeGetType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const auto& attr_type_id = GetAttrAdtTypeId(self); + return attr_type_id.Match([&](const auto& impl) -> std::string { + using T = typename std::decay_t::type; + return T::name(); + }); + } +}; + +struct PirAttributeMatch { + static adt::Result Call( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& packed_args_val) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const auto& attr_type_id = GetAttrAdtTypeId(self); + const auto& type_name = + attr_type_id.Match([&](const auto& impl) -> std::string { + using T = typename std::decay_t::type; + return T::name(); + }); + const auto& packed_args = + axpr::CastToPackedArgs(packed_args_val); + const auto& [args, kwargs] = *packed_args; +
ADT_CHECK(args->size() == 0) << adt::errors::TypeError{ + std::string() + + "PirAttribute.match() supports keyword arguments only, but " + + std::to_string(args->size()) + " positional arguments were given"}; + std::string key = type_name; + if (!kwargs->Has(type_name)) { + if (!kwargs->Has("_")) { + return adt::errors::TypeError{ + std::string() + "PirAttribute.match() failed. no keyword '" + + type_name + "' or '_' provided"}; + } + key = "_"; + } + ADT_LET_CONST_REF(func, kwargs->Get(key)); + ADT_CHECK(axpr::CallableHelper{}.IsCallable(func)) + << adt::errors::TypeError{ + std::string() + + "the arguments of PirAttribute.match() should be callable"}; + if (key == "_") { + return interpreter->InterpretCall(func, {}); + } else { + auto PatternMatch = + [&](const auto& impl) -> adt::Result> { + using T = typename std::decay_t::type; + return MakePirAttributeImpl::GetCallArgs(self_val); + }; + ADT_LET_CONST_REF(attr_make_args, attr_type_id.Match(PatternMatch)); + return interpreter->InterpretCall(func, attr_make_args.vector()); + } + } +}; + +axpr::TypeImpl> GetPirAttributeClass() { + static auto cls(axpr::MakeBuiltinClass( + "PirAttribute", [&](const auto& Yield) { + Yield("__str__", &PirAttributeToString); + Yield("get_type", &PirAttributeGetType::Call); + Yield("match", &PirAttributeMatch::Call); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +adt::Result MakePirAttributeImplBoolAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(bool_val, args.at(0).template CastTo()); + pir::Attribute attr{ + pir::BoolAttribute::get(pir::IrContext::Instance(), bool_val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplBoolAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{attr.data()}; + return 
adt::List{val}; +} + +adt::Result MakePirAttributeImplComplex64Attribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(complex_val, data_val.template TryGet()); + pir::Attribute attr{ + pir::Complex64Attribute::get(pir::IrContext::Instance(), complex_val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplComplex64Attribute::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplComplex128Attribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::Complex128Attribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplComplex128Attribute::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplFloatAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::FloatAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplFloatAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, 
self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplDoubleAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::DoubleAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplDoubleAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplInt32Attribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::Int32Attribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplInt32Attribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplIndexAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::IndexAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> 
+MakePirAttributeImplIndexAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplInt64Attribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::Int64Attribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplInt64Attribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::DataValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplPointerAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, data_val.template TryGet()); + pir::Attribute attr{ + pir::PointerAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplPointerAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{axpr::PointerValue{attr.data()}}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplTypeAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(type_val, args.at(0).template CastTo()); + pir::Attribute attr{ + pir::TypeAttribute::get(pir::IrContext::Instance(), type_val)}; + return 
GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplTypeAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{GetPirTypeClass().New(attr.data())}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplStrAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(val, args.at(0).template CastTo()); + pir::Attribute attr{pir::StrAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplStrAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{attr.AsString()}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplArrayAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "pir.t_vec() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(lst, args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the argument of pir.t_vec() should be a list (not a " + + axpr::GetTypeName(args.at(0)) + ")"}; + std::vector attrs; + attrs.reserve(lst->size()); + for (const auto& arg : *lst) { + ADT_LET_CONST_REF(elt, arg.template CastTo()); + attrs.emplace_back(elt); + } + pir::Attribute attr{ + pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplArrayAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + adt::List lst{}; + for (int 
i = 0; i < attr.size(); ++i) { + lst->emplace_back(GetPirAttributeClass().New(attr.at(i))); + } + return adt::List{axpr::Value{lst}}; +} + +adt::Result MakePirAttributeImplTensorNameAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(val, args.at(0).template CastTo()); + pir::Attribute attr{ + pir::TensorNameAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplTensorNameAttribute::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{attr.data()}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplSymbolAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(shape_or_data, + args.at(0).template CastTo()); + pir::Attribute attr{pir::shape::SymbolAttribute::get( + pir::IrContext::Instance(), shape_or_data)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplSymbolAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa()); + const auto& attr = attribute.dyn_cast(); + axpr::Value val{GetPirShapeOrDataClass().New(attr.data())}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplKernelAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{ + std::string() + "pir." 
 + ::paddle::dialect::KernelAttribute::name() + + "() is not implemented"}; +} + +adt::Result> +MakePirAttributeImplKernelAttribute::GetCallArgs(const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirAttributeImplIntArrayAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + ::paddle::dialect::IntArrayAttribute::name() + + "() takes 1 argument, but " + std::to_string(args.size()) + + " were given"}; + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(0))) + << adt::errors::TypeError{ + std::string() + "the argument 1 of " + + ::paddle::dialect::IntArrayAttribute::name() + + "() should be a list/SerializableList/MutableList (not " + + axpr::GetTypeName(args.at(0)) + ")"}; + std::vector int_array; + ADT_LET_CONST_REF(lst_size, lst.size()); + int_array.reserve(lst_size); + ADT_RETURN_IF_ERR( + lst.Visit([&](const auto& arg) -> adt::Result { + ADT_LET_CONST_REF(elt, arg.template CastTo()); + int_array.emplace_back(elt); + return adt::Continue{}; + })); + pir::Attribute attr{::paddle::dialect::IntArrayAttribute::get( + pir::IrContext::Instance(), int_array)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplIntArrayAttribute::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa<::paddle::dialect::IntArrayAttribute>()); + const auto& attr = attribute.dyn_cast<::paddle::dialect::IntArrayAttribute>(); + adt::List lst{}; + const auto& data = attr.data(); + for (std::size_t i = 0; i < data.size(); ++i) { + int64_t elt = data[i]; + lst->emplace_back(elt); + } + return adt::List{axpr::Value{lst}}; +} + +inline adt::Result ConvertDataValueToScalar( + const axpr::DataValue& data_val) { + return ScalarHelper{}.ConvertFromDataType(data_val); +} + +adt::Result MakePirAttributeImplScalarAttribute::Call( + const axpr::Value& self_val, const std::vector&
args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_val, args.at(0).template CastTo()); + ADT_LET_CONST_REF(val, ConvertDataValueToScalar(data_val)); + pir::Attribute attr{ + ::paddle::dialect::ScalarAttribute::get(pir::IrContext::Instance(), val)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplScalarAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa<::paddle::dialect::ScalarAttribute>()); + const auto& attr = attribute.dyn_cast<::paddle::dialect::ScalarAttribute>(); + ADT_LET_CONST_REF(data_value, ScalarHelper{}.ConvertToDataValue(attr.data())); + axpr::Value val{data_value}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplDataTypeAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + std::optional opt_phi_data_type; + if (args.at(0).template CastableTo()) { + ADT_LET_CONST_REF(type, args.at(0).template CastTo()); + opt_phi_data_type = ::paddle::dialect::TransToPhiDataType(type); + } else if (args.at(0).template CastableTo()) { + ADT_LET_CONST_REF(data_type, args.at(0).template CastTo()); + ADT_LET_CONST_REF(phi_data_type, + axpr::GetPhiDataTypeFromDataType(data_type)); + opt_phi_data_type = phi_data_type; + } else { + return adt::errors::TypeError{ + "the argument 1 of t_dtype() should be a DataType/PirType (not a " + + axpr::GetTypeName(args.at(0)) + ")"}; + } + pir::Attribute attr{::paddle::dialect::DataTypeAttribute::get( + pir::IrContext::Instance(), opt_phi_data_type.value())}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplDataTypeAttribute::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa<::paddle::dialect::DataTypeAttribute>()); + const auto& attr = attribute.dyn_cast<::paddle::dialect::DataTypeAttribute>(); + 
ADT_LET_CONST_REF(data_type, axpr::GetDataTypeFromPhiDataType(attr.data())); + axpr::Value val{data_type}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplPlaceAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(place, args.at(0).template CastTo()); + pir::Attribute attr{::paddle::dialect::PlaceAttribute::get( + pir::IrContext::Instance(), place)}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplPlaceAttribute::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa<::paddle::dialect::PlaceAttribute>()); + const auto& attr = attribute.dyn_cast<::paddle::dialect::PlaceAttribute>(); + axpr::Value val{GetPlaceClass().New(attr.data())}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplDataLayoutAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(data_layout_str, args.at(0).template CastTo()); + std::optional<::common::DataLayout> data_layout; + try { + data_layout = ::common::StringToDataLayout(data_layout_str); + } catch (const std::exception&) { + return adt::errors::ValueError{"StringToDataLayout('" + data_layout_str + + "') failed"}; + } + pir::Attribute attr{::paddle::dialect::DataLayoutAttribute::get( + pir::IrContext::Instance(), data_layout.value())}; + return GetPirAttributeClass().New(attr); +} + +adt::Result> +MakePirAttributeImplDataLayoutAttribute::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(attribute, self_val.template CastTo()); + ADT_CHECK(attribute.isa<::paddle::dialect::DataLayoutAttribute>()); + const auto& attr = + attribute.dyn_cast<::paddle::dialect::DataLayoutAttribute>(); + std::string data_layout_str; + try { + data_layout_str = ::common::DataLayoutToString(attr.data()); + } catch (const std::exception& e) { + return 
adt::errors::ValueError{e.what()}; + } + axpr::Value val{data_layout_str}; + return adt::List{val}; +} + +adt::Result MakePirAttributeImplGroupInfoAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{ + std::string() + "pir." + ::cinn::dialect::GroupInfoAttribute::name() + + "() is not implemented"}; +} + +adt::Result> +MakePirAttributeImplGroupInfoAttribute::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirAttributeImplCINNKernelInfoAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{ + std::string() + "pir." + + ::cinn::dialect::CINNKernelInfoAttribute::name() + + "() is not implemented"}; +} + +adt::Result> +MakePirAttributeImplCINNKernelInfoAttribute::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirAttributeImplUnclassifiedAttribute::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{std::string() + "pir." + + UnclassifiedAttribute::name() + + "() is not implemented"}; +} + +adt::Result> +MakePirAttributeImplUnclassifiedAttribute::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/manual_op.cc b/paddle/ap/src/paddle/pir/manual_op.cc new file mode 100644 index 00000000000000..4e96d190e08305 --- /dev/null +++ b/paddle/ap/src/paddle/pir/manual_op.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ap/include/paddle/pir/manual_op.h" +#include "paddle/common/enforce.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" + +namespace ap::dialect { + +void UpSpiderOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + pir::Value lhs, + pir::Value rhs) { + argument.AddInput(lhs); + argument.AddInput(rhs); +} + +void DownSpiderOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value x) { + argument.inputs = {x}; + argument.output_types = {x.type()}; +} + +bool DownSpiderOp::InferSymbolicShape( + pir::InferSymbolicShapeContext* infer_context) { + infer_context->SetShapeOrDataForValue( + result(0), infer_context->GetShapeOrDataForValue(operand_source(0))); + return true; +} + +const char* + LoadFromRegisterOp::attributes_name[LoadFromRegisterOp::attributes_num] = { + "type", "symbolic_shape_or_data", "name", "register_var_name"}; + +void LoadFromRegisterOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Type output_type, + const symbol::ShapeOrDataDimExprs& shape_or_data, + const std::string& name, + const std::string& register_var_name) { + argument.inputs = {}; + argument.output_types = {output_type}; + argument.AddAttribute( + "type", pir::TypeAttribute::get(pir::IrContext::Instance(), output_type)); + argument.AddAttribute("symbolic_shape_or_data", + pir::shape::SymbolAttribute::get( + 
pir::IrContext::Instance(), shape_or_data)); + argument.AddAttribute( + "name", pir::StrAttribute::get(pir::IrContext::Instance(), name)); + argument.AddAttribute( + "register_var_name", + pir::StrAttribute::get(pir::IrContext::Instance(), register_var_name)); +} + +bool LoadFromRegisterOp::InferSymbolicShape( + pir::InferSymbolicShapeContext* infer_context) { + const auto& symbolic_shape_or_data = + this->attributes() + .at("symbolic_shape_or_data") + .dyn_cast() + .data(); + infer_context->SetShapeOrDataForValue(result(0), symbolic_shape_or_data); + return true; +} + +const char* + StoreToRegisterOp::attributes_name[StoreToRegisterOp::attributes_num] = { + "name", "register_var_name"}; + +void StoreToRegisterOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value x, + const std::string& name, + const std::string& register_var_name) { + argument.inputs = {x}; + argument.AddAttribute( + "name", pir::StrAttribute::get(pir::IrContext::Instance(), name)); + argument.AddAttribute( + "register_var_name", + pir::StrAttribute::get(pir::IrContext::Instance(), register_var_name)); +} + +bool StoreToRegisterOp::InferSymbolicShape( + pir::InferSymbolicShapeContext* infer_context) { + return true; +} + +const char* + LoadFromGlobalOp::attributes_name[LoadFromGlobalOp::attributes_num] = { + "index_func_unique_id"}; + +void LoadFromGlobalOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value input, + const std::string& index_func_unique_id) { + argument.inputs = {input}; + argument.output_types = {input.type()}; + argument.AddAttribute( + "index_func_unique_id", + pir::StrAttribute::get(pir::IrContext::Instance(), index_func_unique_id)); +} + +bool LoadFromGlobalOp::InferSymbolicShape( + pir::InferSymbolicShapeContext* infer_context) { + infer_context->SetShapeOrDataForValue( + result(0), infer_context->GetShapeOrDataForValue(operand_source(0))); + return true; +} + +const char* 
StoreToGlobalOp::attributes_name[StoreToGlobalOp::attributes_num] = + {"index_func_unique_id"}; + +void StoreToGlobalOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value var, + pir::Value val, + const std::string& index_func_unique_id) { + argument.inputs = {var, val}; + argument.output_types = {}; + argument.AddAttribute( + "index_func_unique_id", + pir::StrAttribute::get(pir::IrContext::Instance(), index_func_unique_id)); +} + +bool StoreToGlobalOp::InferSymbolicShape( + pir::InferSymbolicShapeContext* infer_context) { + return true; +} + +} // namespace ap::dialect + +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp); +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp); +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp); +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::StoreToRegisterOp); +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromGlobalOp); +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::StoreToGlobalOp); diff --git a/paddle/ap/src/paddle/pir/op_dialect.cc b/paddle/ap/src/paddle/pir/op_dialect.cc new file mode 100644 index 00000000000000..ee0d69720303b1 --- /dev/null +++ b/paddle/ap/src/paddle/pir/op_dialect.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/paddle/pir/op_dialect.h" +#include "paddle/ap/include/paddle/pir/manual_op.h" + +namespace ap { +namespace dialect { + +OperatorDialect::OperatorDialect(::pir::IrContext *context) + : ::pir::Dialect( + name(), context, ::pir::TypeId::get()) { + this->initialize(); +} + +void OperatorDialect::initialize() { + RegisterOp(); + RegisterOp(); + RegisterOp(); + RegisterOp(); + RegisterOp(); + RegisterOp(); +} + +} // namespace dialect +} // namespace ap + +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::OperatorDialect) diff --git a/paddle/ap/src/paddle/pir/packed_ir_op_inner_source_pattern_helper.cc b/paddle/ap/src/paddle/pir/packed_ir_op_inner_source_pattern_helper.cc new file mode 100644 index 00000000000000..09923b805b0300 --- /dev/null +++ b/paddle/ap/src/paddle/pir/packed_ir_op_inner_source_pattern_helper.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h" +#include "paddle/ap/include/drr/drr_graph_descriptor.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/topo_kind.h" +#include "paddle/ap/include/drr/value_method_class.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/ir_match/graph_matcher.h" +#include "paddle/ap/include/ir_match/ir_match_ctx.h" +#include "paddle/ap/include/paddle/pir_graph_descriptor.h" +#include "paddle/ap/include/paddle/pir_node.h" +#include "paddle/ap/include/paddle/pir_node_descriptor.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace ap::paddle { + +namespace { + +std::optional GetPirNodeBlock(const PirNode& pir_node) { + using RetT = std::optional; + return pir_node.Match( + [&](const NativeIrValue& impl) -> RetT { return std::nullopt; }, + [&](const PackedIrValue& impl) -> RetT { return std::nullopt; }, + [&](const NativeIrOpOperand& impl) -> RetT { return std::nullopt; }, + [&](const PackedIrOpOperand& impl) -> RetT { return std::nullopt; }, + [&](const RefIrOpOperand& impl) -> RetT { return std::nullopt; }, + [&](const NativeIrOp& impl) -> RetT { return impl.op->GetParent(); }, + [&](const PackedIrOp& impl) -> RetT { + return impl.fusion_op->GetParent(); + }, + [&](const NativeIrOpResult& impl) -> RetT { return std::nullopt; }, + [&](const PackedIrOpResult& impl) -> RetT { return std::nullopt; }, + [&](const RefIrValue& impl) -> RetT { return std::nullopt; }, + [&](const RefIrOp& impl) -> RetT { return std::nullopt; }, + [&](const RefIrOpResult& impl) -> RetT { return std::nullopt; }); +} + +adt::Result GetDrrYieldNode( + const drr::SourcePatternCtx& src_ptn_ctx) { + std::optional yield_node; + for (const auto& drr_node : src_ptn_ctx->node_arena->nodes()) { + if (!drr_node.template Has>()) { + continue; + } + ADT_LET_CONST_REF(drr_op, + drr_node.template TryGet>()); + 
if (drr_op->op_declare->op_name == pir::YieldOp::name()) { + ADT_CHECK(!yield_node.has_value()); + yield_node = drr_node; + } + } + ADT_CHECK(yield_node.has_value()); + return yield_node.value(); +} + +adt::Result GetPirYieldNode(const pir::Block* block) { + for (const auto& op : *block) { + if (op.template isa()) { + return NativeIrOp{const_cast(&op)}; + } + } + return adt::errors::ValueError{"no yield op found in fusion_op block"}; +} + +} // namespace + +adt::Result>> +PackedIrOpInnerSourcePatternHelper::Match( + const PackedIrOp& ir_op, const drr::SourcePatternCtx& src_ptn_ctx) { + return Match(ir_op.fusion_op.block(), src_ptn_ctx); +} + +adt::Result>> +PackedIrOpInnerSourcePatternHelper::Match( + const pir::Block* block, const drr::SourcePatternCtx& src_ptn_ctx) { + auto BelongToThisBlockOrNotOp = + [&](const PirNode& node) -> adt::Result { + const auto& opt_block = GetPirNodeBlock(node); + if (!opt_block.has_value()) { + return true; + } + return opt_block.value() == block; + }; + using Default = drr::topo_kind::Default; + using BlockBound = drr::topo_kind::BlockBound; + using DrrGraphNode = graph::Node; + ap::graph::GraphDescriptor pir_graph( + BelongToThisBlockOrNotOp); + ap::graph::GraphDescriptor src_ptn_graph{}; + ap::ir_match::GraphMatcher graph_matcher( + pir_graph, src_ptn_graph); + ADT_LET_CONST_REF(drr_yield_node, GetDrrYieldNode(src_ptn_ctx)); + ADT_LET_CONST_REF(pir_yield_node, GetPirYieldNode(block)); + ADT_LET_CONST_REF( + graph_match_ctx, + graph_matcher.MatchByAnchor(pir_yield_node, drr_yield_node.node())); + ADT_LET_CONST_REF( + graph_matched, + graph_matcher.IsGraphMatched(graph_match_ctx, drr_yield_node.node())); + if (!graph_matched) { + return std::nullopt; + } + auto GetPirNode = [](const pir::Operation* op) -> PirNode { + auto* mut_op = const_cast(op); + if (mut_op->isa<::cinn::dialect::FusionOp>()) { + return PackedIrOp{mut_op->dyn_cast<::cinn::dialect::FusionOp>()}; + } else { + return NativeIrOp{mut_op}; + } + }; + for (const auto& 
op : *block) { + const auto& opt_drr_node = + graph_match_ctx->GetOptMatchedSmallGraphNode(GetPirNode(&op)); + if (!opt_drr_node.has_value()) { + return std::nullopt; + } + } + return graph_match_ctx; +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/pass_manager_method_class.cc b/paddle/ap/src/paddle/pir/pass_manager_method_class.cc new file mode 100644 index 00000000000000..f915b218da6cd5 --- /dev/null +++ b/paddle/ap/src/paddle/pir/pass_manager_method_class.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pir/pass_manager_method_class.h" + +namespace ap::paddle { + +struct PirPassManagerMethodClass { + using Self = ap::paddle::PassManager; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self.shared_ptr().get(); + std::ostringstream ss; + ss << ""; + return ss.str(); + } + + static adt::Result AddPass( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1); + if (args.at(0).template Has()) { + return self_val; + } + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(pass, args.at(0).template CastTo()) + << adt::errors::TypeError{std::string() + + "PirPassManager.add_pass() failed. 
the " + "argument 1 should be a PirPass (not a " + + axpr::GetTypeName(args.at(0)) + ")."}; + ADT_CHECK(pass->pir_pass != nullptr) + << adt::errors::TypeError{std::string() + "PirPass being used."}; + self->pir_pass_manager->AddPass(std::move(pass.shared_ptr()->pir_pass)); + return self_val; + } + + static adt::Result Run(const axpr::Value& self_val, + const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + ADT_LET_CONST_REF(program, + args.at(0).template CastTo()) + << adt::errors::TypeError{std::string() + + "PirPassManager.run() failed. the argument 1 " + "should be a PirProgram (not a " + + axpr::GetTypeName(args.at(0)) + ")."}; + self->pir_pass_manager->Run(program->pir_program.get()); + return adt::Nothing{}; + } +}; + +axpr::TypeImpl> +GetPirPassManagerClass() { + using Impl = PirPassManagerMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "PirPassManager", [&](const auto& Yield) { + Yield("__str__", &Impl::ToString); + Yield("add_pass", &Impl::AddPass); + Yield("run", &Impl::Run); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/pass_method_class.cc b/paddle/ap/src/paddle/pir/pass_method_class.cc new file mode 100644 index 00000000000000..a3718da95e6f0c --- /dev/null +++ b/paddle/ap/src/paddle/pir/pass_method_class.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pir/pass_method_class.h" + +namespace ap::paddle { + +struct PirPassMethodClass { + using Self = ap::paddle::Pass; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const void* ptr = self->pir_pass.get(); + std::ostringstream ss; + ss << ""; + return ss.str(); + } +}; + +axpr::TypeImpl> GetPirPassClass() { + using Impl = PirPassMethodClass; + static auto cls(axpr::MakeBuiltinClass( + "PirPass", + [&](const auto& Yield) { Yield("__str__", &Impl::ToString); })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/pir_method_class.cc b/paddle/ap/src/paddle/pir/pir_method_class.cc new file mode 100644 index 00000000000000..f1d1b653c61f6c --- /dev/null +++ b/paddle/ap/src/paddle/pir/pir_method_class.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pir/pir_method_class.h" +#include "paddle/ap/include/axpr/module_mgr.h" +#include "paddle/ap/include/paddle/pir_node.h" + +namespace ap::paddle { + +void ForceLinkPir() { + // Do nothing. 
+} + +template +void DefineMethods(Builder* m) { + m->Def("UndefinedPlace", &CreateUndefinedPlace); + m->Def("CPUPlace", &CreateCPUPlace); + m->Def("GPUPlace", &CreateGPUPlace); + m->Def("GPUPinnedPlace", &CreateGPUPinnedPlace); + m->Def("XPUPlace", &CreateXPUPlace); + m->Def("IPUPlace", &CreateIPUPlace); + m->Def("CustomPlace", &CreateCustomPlace); +#define DEF_MAKE_ATTRIBUTE(attr_type) \ + m->Def(attr_type::name(), &MakePirAttributeImpl::Call); + FOR_EACH_PIR_ATTRIBUTE_TYPE(DEF_MAKE_ATTRIBUTE); +#undef DEF_MAKE_ATTRIBUTE + +#define DEF_MAKE_TYPE(cls) m->Def(cls::name(), &MakePirTypeImpl::Call); + FOR_EACH_PIR_ALTERNATIVE_TYPE(DEF_MAKE_TYPE); +#undef DEF_MAKE_TYPE +} + +REGISTER_AP_BUILTIN_MODULE("pir", [](auto* m) { DefineMethods(m); }); + +axpr::TypeImpl> GetPirClass() { + static auto cls( + axpr::MakeBuiltinClass("pir", [&](const auto& Yield) { + Yield("UndefinedPlace", &CreateUndefinedPlace); + Yield("CPUPlace", &CreateCPUPlace); + Yield("GPUPlace", &CreateGPUPlace); + Yield("GPUPinnedPlace", &CreateGPUPinnedPlace); + Yield("XPUPlace", &CreateXPUPlace); + Yield("IPUPlace", &CreateIPUPlace); + Yield("CustomPlace", &CreateCustomPlace); +#define YIELD_MAKE_ATTRIBUTE(attr_type) \ + Yield(attr_type::name(), &MakePirAttributeImpl::Call); + FOR_EACH_PIR_ATTRIBUTE_TYPE(YIELD_MAKE_ATTRIBUTE); +#undef YIELD_MAKE_ATTRIBUTE + +#define YIELD_MAKE_TYPE(cls) Yield(cls::name(), &MakePirTypeImpl::Call); + FOR_EACH_PIR_ALTERNATIVE_TYPE(YIELD_MAKE_TYPE); +#undef YIELD_MAKE_TYPE + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc b/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc new file mode 100644 index 00000000000000..c637a9eb72316e --- /dev/null +++ b/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc @@ -0,0 +1,430 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h" +#include +#include "glog/logging.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/cps_interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/drr/builtin_frame_util.h" +#include "paddle/ap/include/drr/drr_node_descriptor.h" +#include "paddle/ap/include/drr/ir_op.h" +#include "paddle/ap/include/drr/ir_value.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/source_pattern_ctx.h" +#include "paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h" +#include "paddle/ap/include/drr/tags.h" +#include "paddle/ap/include/drr/value_method_class.h" +#include "paddle/ap/include/graph/node.h" +#include "paddle/ap/include/paddle/pir/attribute_method_class.h" +#include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h" +#include "paddle/ap/include/paddle/pir/pir_method_class.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" +#include "paddle/pir/include/core/block.h" +#include "paddle/pir/include/core/operation.h" + +namespace ap::paddle { + +namespace { + +adt::Result(const pir::Operation*)>> +MakeOpNameGetter(const pir::Block* block) { + using CacheT = std::map; + CacheT cache{}; + { + int i = 0; + for (auto& op : 
*block) { + ADT_CHECK(cache.emplace(&op, op.name() + "_" + std::to_string(i)).second); + ++i; + } + } + using RetFunc = + std::function(const pir::Operation*)>; + + RetFunc func = [cache = std::move(cache)]( + const pir::Operation* op) -> adt::Result { + const auto& iter = cache.find(op); + ADT_CHECK(iter != cache.end()); + return iter->second; + }; + return func; +} + +adt::Result(pir::Value)>> +MakeTensorNameGetter(const pir::Block* block) { + using CacheT = std::map; + CacheT cache{}; + { + int i = 0; + for (auto& op : *block) { + for (int j = 0; j < op.num_results(); ++j) { + pir::Value value = op.result(j); + const auto& name = + op.name() + "_" + std::to_string(i) + "_" + std::to_string(j); + ADT_CHECK(cache.emplace(value, name).second); + } + ++i; + } + } + { + int i = 0; + for (auto& op : *block) { + for (int j = 0; j < op.num_operands(); ++j) { + pir::Value value = op.operand_source(j); + if (cache.count(value) > 0) { + continue; + } + const auto& name = std::string() + "input_" + std::to_string(i++); + ADT_CHECK(cache.emplace(value, name).second); + } + } + } + using RetFunc = std::function(pir::Value)>; + RetFunc func = + [cache = std::move(cache)](pir::Value value) -> adt::Result { + const auto& iter = cache.find(value); + ADT_CHECK(iter != cache.end()); + return iter->second; + }; + return func; +} + +class SourcePatternCtxBuilder { + drr::SourcePatternCtx src_ptn_ctx_; + std::unique_ptr interpreter_; + + public: + SourcePatternCtxBuilder(const drr::SourcePatternCtx& src_ptn_ctx, + std::unique_ptr&& interpreter) + : src_ptn_ctx_(src_ptn_ctx), interpreter_(std::move(interpreter)) {} + + const drr::SourcePatternCtx& src_ptn_ctx() const { return src_ptn_ctx_; } + + adt::Result BuildNativeOp(const std::string& op_name, + const std::string& op_unique_name) { + static const axpr::Lambda func([]() { + axpr::LambdaExprBuilder lmd; + const auto& anf_expr = lmd.Lambda( + {"o", "op_name", "op_unique_name"}, + [&](axpr::LetContext& ctx) -> axpr::AnfExpr { + auto& 
native_op = + ctx.Var("o").Attr("ap_native_op").Call(ctx.Var("op_name")); + ctx.Var("o").SetAttr(ctx.Var("op_unique_name"), native_op); + return ctx.None(); + }); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + CHECK(core_expr.Has>()); + const auto& atomic = core_expr.Get>(); + CHECK(atomic.Has>()); + return atomic.Get>(); + }()); + const auto& op_pattern_ctx = drr::GetSrcPtnOpPatternCtxClass().New( + drr::SrcPtn(src_ptn_ctx_->op_pattern_ctx)); + ADT_RETURN_IF_ERR(interpreter_->Interpret(func, + {axpr::Value{op_pattern_ctx}, + axpr::Value{op_name}, + axpr::Value{op_unique_name}})); + return adt::Ok{}; + } + + adt::Result SetOpAttr(const std::string& op_unique_name, + const std::string& attr_name, + pir::Attribute attribute) { + static const axpr::Lambda func([]() { + axpr::LambdaExprBuilder lmd; + const auto& anf_expr = + lmd.Lambda({"o", "op_unique_name", "attr_name", "attr_val"}, + [&](axpr::LetContext& ctx) -> axpr::AnfExpr { + ctx.Var("o") + .Attr(ctx.Var("op_unique_name")) + .SetAttr(ctx.Var("attr_name"), ctx.Var("attr_val")); + return ctx.None(); + }); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + CHECK(core_expr.Has>()); + const auto& atomic = core_expr.Get>(); + CHECK(atomic.Has>()); + return atomic.Get>(); + }()); + const auto& op_pattern_ctx = drr::GetSrcPtnOpPatternCtxClass().New( + drr::SrcPtn(src_ptn_ctx_->op_pattern_ctx)); + const auto& attr_val = GetPirAttributeClass().New(attribute); + ADT_RETURN_IF_ERR(interpreter_->Interpret(func, + {axpr::Value{op_pattern_ctx}, + axpr::Value{op_unique_name}, + axpr::Value{attr_name}, + axpr::Value{attr_val}})); + return adt::Ok{}; + } + + adt::Result SetTensorType(const std::string& tensor_unique_name, + pir::Type type) { + static const axpr::Lambda func([]() { + axpr::LambdaExprBuilder lmd; + const auto& anf_expr = + lmd.Lambda({"t", "tensor_unique_name", "type_val"}, + [&](axpr::LetContext& ctx) -> axpr::AnfExpr { + ctx.Var("t") + 
.Attr(ctx.Var("tensor_unique_name")) + .SetAttr("type", ctx.Var("type_val")); + return ctx.None(); + }); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + CHECK(core_expr.Has>()); + const auto& atomic = core_expr.Get>(); + CHECK(atomic.Has>()); + return atomic.Get>(); + }()); + const auto& tensor_pattern_ctx = drr::GetSrcPtnTensorPatternCtxClass().New( + drr::SrcPtn(src_ptn_ctx_->tensor_pattern_ctx)); + auto GetType = [&]() -> axpr::Value { return GetPirTypeClass().New(type); }; + ADT_RETURN_IF_ERR(interpreter_->Interpret(func, + {axpr::Value{tensor_pattern_ctx}, + axpr::Value{tensor_unique_name}, + axpr::Value{GetType()}})); + return adt::Ok{}; + } + + adt::Result Connect( + const std::string& op_unique_name, + const std::vector>& input_tensor_names, + const std::vector& output_tensor_names) { + static const axpr::Lambda func([]() { + axpr::LambdaExprBuilder lmd; + const auto& anf_expr = lmd.Lambda( + {"o", + "t", + "op_unique_name", + "input_tensor_names_val", + "output_tensor_names_val"}, + [&](axpr::LetContext& ctx) -> axpr::AnfExpr { + const auto& get_or_create = + ctx.Var("t").Attr("get_or_create_tensor"); + auto& op = ctx.Var("o").Attr(ctx.Var("op_unique_name")); + op.Call(ctx.Var("map").Call(get_or_create, + ctx.Var("input_tensor_names_val")), + ctx.Var("map").Call(get_or_create, + ctx.Var("output_tensor_names_val"))); + return ctx.None(); + }); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + CHECK(core_expr.Has>()); + const auto& atomic = core_expr.Get>(); + CHECK(atomic.Has>()); + return atomic.Get>(); + }()); + const auto& op_pattern_ctx = drr::GetSrcPtnOpPatternCtxClass().New( + drr::SrcPtn(src_ptn_ctx_->op_pattern_ctx)); + const auto& tensor_pattern_ctx = drr::GetSrcPtnTensorPatternCtxClass().New( + drr::SrcPtn(src_ptn_ctx_->tensor_pattern_ctx)); + const auto& input_tensor_names_val = OptStrsToList(input_tensor_names); + const auto& output_tensor_names_val = StrsToList(output_tensor_names); + ADT_RETURN_IF_ERR( 
+ interpreter_->Interpret(func, + { + axpr::Value{op_pattern_ctx}, + axpr::Value{tensor_pattern_ctx}, + axpr::Value{op_unique_name}, + axpr::Value{input_tensor_names_val}, + axpr::Value{output_tensor_names_val}, + })); + return adt::Ok{}; + } + + private: + adt::List StrsToList(const std::vector& strs) { + adt::List ret; + ret->reserve(strs.size()); + for (const auto& str : strs) { + ret->emplace_back(axpr::Value{str}); + } + return ret; + } + + adt::List OptStrsToList( + const std::vector>& strs) { + adt::List ret; + ret->reserve(strs.size()); + for (const auto& str : strs) { + if (str.has_value()) { + ret->emplace_back(axpr::Value{str.value()}); + } else { + ret->emplace_back(adt::Nothing{}); + } + } + return ret; + } +}; + +std::unique_ptr MakeSourcePatternCtxBuilder( + const std::shared_ptr& drr_ctx) { + auto node_arena = std::make_shared>(); + drr::SourcePatternCtx src_ptn_ctx{ + node_arena, + drr::OpPatternCtx{ + node_arena, std::map{}, drr_ctx}, + drr::TensorPatternCtx{ + node_arena, std::map{}, drr_ctx}}; + const auto& builtin_frame = + ap::drr::MakeBuiltinFrameAttrMap([&](const auto&) {}); + auto interpreter = std::make_unique( + builtin_frame, drr_ctx->circlable_ref_list); + return std::make_unique(src_ptn_ctx, + std::move(interpreter)); +} + +adt::Result InitSrcPtnCtxNativeIrOps( + SourcePatternCtxBuilder* builder, + pir::Block* block, + const std::function(const pir::Operation*)>& + GetOpName) { + for (auto& op : *block) { + auto* op_ptr = &op; + ADT_LET_CONST_REF(op_unique_name, GetOpName(op_ptr)); + ADT_RETURN_IF_ERR(builder->BuildNativeOp(op.name(), op_unique_name)); + for (const auto& [attr_name, attr_val] : op.attributes()) { + ADT_RETURN_IF_ERR( + builder->SetOpAttr(op_unique_name, attr_name, attr_val)); + } + } + return adt::Ok{}; +} + +adt::Result InitSrcPtnCtxNativeIrValues( + SourcePatternCtxBuilder* builder, + pir::Block* block, + const std::function(pir::Value)>& GetTensorName) { + std::unordered_set inited; + auto InitType = [&](pir::Value 
value) -> adt::Result { + if (inited.insert(value).second) { + ADT_LET_CONST_REF(tensor_name, GetTensorName(value)); + ADT_RETURN_IF_ERR(builder->SetTensorType(tensor_name, value.type())); + } + return adt::Ok{}; + }; + for (auto& op : *block) { + for (int i = 0; i < op.num_operands(); ++i) { + if (op.operand_source(i)) { + ADT_RETURN_IF_ERR(InitType(op.operand_source(i))); + } + } + for (int i = 0; i < op.num_results(); ++i) { + ADT_RETURN_IF_ERR(InitType(op.result(i))); + } + } + return adt::Ok{}; +} + +adt::Result InitSrcPtnCtxConnections( + SourcePatternCtxBuilder* builder, + pir::Block* block, + const std::function(const pir::Operation*)>& + GetOpName, + const std::function(pir::Value)>& GetTensorName) { + for (auto& op : *block) { + auto* op_ptr = &op; + ADT_LET_CONST_REF(op_unique_name, GetOpName(op_ptr)); + std::vector> input_tensor_names{}; + input_tensor_names.reserve(op.num_operands()); + for (int i = 0; i < op.num_operands(); ++i) { + if (op.operand_source(i)) { + ADT_LET_CONST_REF(input_name, GetTensorName(op.operand_source(i))); + input_tensor_names.emplace_back(input_name); + } else { + input_tensor_names.emplace_back(std::nullopt); + } + } + std::vector output_tensor_names{}; + output_tensor_names.reserve(op.num_results()); + for (int i = 0; i < op.num_results(); ++i) { + ADT_LET_CONST_REF(output_name, GetTensorName(op.result(i))); + output_tensor_names.emplace_back(output_name); + } + ADT_RETURN_IF_ERR(builder->Connect( + op_unique_name, input_tensor_names, output_tensor_names)); + } + return adt::Ok{}; +} + +} // namespace + +adt::Result> +PirNodeMatchedSrcPtnCtxHelper::MakeInnerMatchedSrcPtnCtxHelper( + const drr::PackedIrOp& drr_packed_ir_op) { + ADT_LET_CONST_REF(pir_node, + match_ctx_->GetSoleBigGraphNode(drr_packed_ir_op->node)); + ADT_LET_CONST_REF(pir_packed_ir_op, pir_node.template TryGet()); + ADT_LET_CONST_REF( + op_pattern_ctx, + adt::WeakPtrLock(drr_packed_ir_op->op_declare->op_pattern_ctx)); + ADT_LET_CONST_REF(drr_ctx, 
adt::WeakPtrLock(op_pattern_ctx->drr_ctx)); + ADT_LET_CONST_REF( + src_ptn_ctx, + ConvertBlockToSrcPtnCtx(pir_packed_ir_op.fusion_op.block(), drr_ctx)); + PackedIrOpInnerSourcePatternHelper inner_src_ptn_ctx_helper{}; + ADT_LET_CONST_REF( + opt_match_ctx, + inner_src_ptn_ctx_helper.Match(pir_packed_ir_op, src_ptn_ctx)); + ADT_CHECK(opt_match_ctx.has_value()); + std::shared_ptr + matched_src_ptn_ctx_helper = + std::make_shared( + src_ptn_ctx, opt_match_ctx.value()); + return matched_src_ptn_ctx_helper; +} + +adt::Result PirNodeMatchedSrcPtnCtxHelper::VisitNativeIrOpAttr( + const drr::NativeIrOp& drr_native_ir_op, + const std::function(const std::string& attr_name, + const axpr::Value& attr_val)>& + DoEachAttr) { + ADT_LET_CONST_REF(pir_node, + match_ctx_->GetSoleBigGraphNode(drr_native_ir_op->node)); + ADT_LET_CONST_REF(pir_native_ir_op, pir_node.template TryGet()); + for (const auto& [attr_name, attr] : pir_native_ir_op.op->attributes()) { + if (!attr) continue; + if (attr_name == "op_callstack") continue; + if (attr_name == "sym_shape_str") continue; + const auto& attr_val = GetPirAttributeClass().New(attr); + ADT_RETURN_IF_ERR(DoEachAttr(attr_name, attr_val)); + } + return adt::Ok{}; +} + +adt::Result PirNodeMatchedSrcPtnCtxHelper::GetNativeIrValueType( + const drr::NativeIrValue& native_ir_value) { + ADT_LET_CONST_REF(pir_node, + match_ctx_->GetSoleBigGraphNode(native_ir_value->node)); + ADT_LET_CONST_REF(pir_native_ir_value, + pir_node.template TryGet()); + return GetPirTypeClass().New(pir_native_ir_value.value.type()); +} + +adt::Result +PirNodeMatchedSrcPtnCtxHelper::ConvertBlockToSrcPtnCtx( + pir::Block* block, const std::shared_ptr& drr_ctx) { + ADT_LET_CONST_REF(GetOpName, MakeOpNameGetter(block)); + ADT_LET_CONST_REF(GetTensorName, MakeTensorNameGetter(block)); + auto builder = MakeSourcePatternCtxBuilder(drr_ctx); + ADT_RETURN_IF_ERR(InitSrcPtnCtxNativeIrOps(builder.get(), block, GetOpName)); + ADT_RETURN_IF_ERR( + 
InitSrcPtnCtxNativeIrValues(builder.get(), block, GetTensorName)); + ADT_RETURN_IF_ERR( + InitSrcPtnCtxConnections(builder.get(), block, GetOpName, GetTensorName)); + return builder->src_ptn_ctx(); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/pir_to_anf_expr_helper.cc b/paddle/ap/src/paddle/pir/pir_to_anf_expr_helper.cc new file mode 100644 index 00000000000000..cc17457c94c70a --- /dev/null +++ b/paddle/ap/src/paddle/pir/pir_to_anf_expr_helper.cc @@ -0,0 +1,707 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/drr/value_method_class.h" +#include "paddle/ap/include/paddle/phi/scalar_helper.h" + +namespace ap::paddle { + +namespace { + +template +struct TypeToAnfExprConverter; + +template <> +struct TypeToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(NullType::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::VectorType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + ADT_CHECK(type.template isa<::pir::VectorType>()); + auto vec_type = type.template dyn_cast<::pir::VectorType>(); + std::vector args; + args.reserve(vec_type.size()); + for (const auto& elt_type : vec_type.data()) { + ADT_LET_CONST_REF( + elt_anf_expr, + PirToAnfExprHelper{}.ConvertPirTypeToAnfExpr(ctx, elt_type)); + args.emplace_back(elt_anf_expr); + } + return ctx->Var("pir") + .Attr(::pir::VectorType::name()) + .Call(ctx->Var(axpr::kBuiltinList()).Apply(args)); + } +}; + +adt::Result ConvertToDataLayoutAnfExpr( + axpr::LetContext* ctx, const ::common::DataLayout& data_layout) { + try { + const auto& data_layout_str = ::common::DataLayoutToString(data_layout); + return ctx->String(data_layout_str); + } catch (const std::exception& e) { + return adt::errors::ValueError{e.what()}; + } +} + +template <> +struct TypeToAnfExprConverter<::pir::DenseTensorType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + ADT_CHECK(type.template isa<::pir::DenseTensorType>()); + auto dense_tensor_type = type.template dyn_cast<::pir::DenseTensorType>(); + + // dtype + ADT_LET_CONST_REF(dtype_anf_expr, + PirToAnfExprHelper{}.ConvertPirTypeToAnfExpr( + ctx, dense_tensor_type.dtype())); + + // dims + std::vector dim_elts; + const auto& dims = dense_tensor_type.dims(); + for (int i = 0; i < dims.size(); ++i) { + 
dim_elts.push_back(ctx->Int64(dims.at(i))); + } + const auto& dims_anf_expr = ctx->Var(axpr::kBuiltinList()).Apply(dim_elts); + + // data layout + ADT_LET_CONST_REF( + data_layout_anf_expr, + ConvertToDataLayoutAnfExpr(ctx, dense_tensor_type.data_layout())); + return ctx->Var("pir") + .Attr(::pir::DenseTensorType::name()) + .Call(dtype_anf_expr, dims_anf_expr, data_layout_anf_expr); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::BFloat16Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::BFloat16Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Float16Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Float16Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Float32Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Float32Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Float64Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Float64Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Int8Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Int8Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::UInt8Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::UInt8Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Int16Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Int16Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Int32Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + 
return ctx->Var("pir").Attr(::pir::Int32Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Int64Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Int64Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::IndexType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::IndexType::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::BoolType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::BoolType::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Complex64Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Complex64Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::pir::Complex128Type> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir").Attr(::pir::Complex128Type::name()).Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::paddle::dialect::SelectedRowsType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return ctx->Var("pir") + .Attr(::paddle::dialect::SelectedRowsType::name()) + .Call(); + } +}; + +template <> +struct TypeToAnfExprConverter<::paddle::dialect::DenseTensorArrayType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + ADT_CHECK(type.template isa<::paddle::dialect::DenseTensorArrayType>()); + auto dense_tensor_array_type = + type.template dyn_cast<::paddle::dialect::DenseTensorArrayType>(); + + // dtype + ADT_LET_CONST_REF(dtype_anf_expr, + PirToAnfExprHelper{}.ConvertPirTypeToAnfExpr( + ctx, dense_tensor_array_type.dtype())); + + // dims + std::vector dim_elts; + const auto& dims = dense_tensor_array_type.dims(); + for (int i = 0; i < dims.size(); ++i) { + 
dim_elts.push_back(ctx->Int64(dims.at(i))); + } + const auto& dims_anf_expr = ctx->Var(axpr::kBuiltinList()).Apply(dim_elts); + + // data layout + ADT_LET_CONST_REF( + data_layout_anf_expr, + ConvertToDataLayoutAnfExpr(ctx, dense_tensor_array_type.data_layout())); + return ctx->Var("pir") + .Attr(::paddle::dialect::DenseTensorArrayType::name()) + .Call(dtype_anf_expr, dims_anf_expr, data_layout_anf_expr); + } +}; + +template <> +struct TypeToAnfExprConverter<::paddle::dialect::SparseCooTensorType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return adt::errors::NotImplementedError{ + std::string() + ::paddle::dialect::SparseCooTensorType::name() + + "() is not implemented"}; + } +}; + +template <> +struct TypeToAnfExprConverter<::paddle::dialect::SparseCsrTensorType> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return adt::errors::NotImplementedError{ + std::string() + ::paddle::dialect::SparseCsrTensorType::name() + + "() is not implemented"}; + } +}; + +template <> +struct TypeToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Type type) { + return adt::errors::NotImplementedError{ + std::string() + UnclassifiedType::name() + "() is not implemented"}; + } +}; + +template +struct AttrToAnfExprConverter; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_data_val = ctx->Bool(attr_impl.data()); + return ctx->Var("pir").Attr(pir::BoolAttribute::name()).Call(attr_data_val); + } +}; + +adt::Result ConvertToComplex64AnfExpr( + axpr::LetContext* ctx, const axpr::complex64& attr_data) { + const auto& real = ctx->Var("DataValue") + .Attr("float32") + .Call(ctx->String(std::to_string(attr_data.real))); + const auto& imag = ctx->Var("DataValue") + .Attr("float32") + .Call(ctx->String(std::to_string(attr_data.imag))); + 
const auto& attr_data_val = + ctx->Var("DataValue").Attr("complex64").Call(real, imag); + return attr_data_val; +} + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_data = attr_impl.data(); + ADT_LET_CONST_REF(attr_data_val, ConvertToComplex64AnfExpr(ctx, attr_data)); + return ctx->Var("pir") + .Attr(pir::Complex64Attribute::name()) + .Call(attr_data_val); + } +}; + +adt::Result ConvertToComplex128AnfExpr( + axpr::LetContext* ctx, const axpr::complex128& attr_data) { + const auto& real = ctx->Var("DataValue") + .Attr("float64") + .Call(ctx->String(std::to_string(attr_data.real))); + const auto& imag = ctx->Var("DataValue") + .Attr("float64") + .Call(ctx->String(std::to_string(attr_data.imag))); + const auto& attr_data_val = + ctx->Var("DataValue").Attr("complex128").Call(real, imag); + return attr_data_val; +} + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_data = attr_impl.data(); + ADT_LET_CONST_REF(attr_data_val, + ConvertToComplex128AnfExpr(ctx, attr_data)); + return ctx->Var("pir") + .Attr(pir::Complex128Attribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_str = ctx->String(std::to_string(attr_impl.data())); + const auto& attr_data_val = + ctx->Var("DataValue").Attr("float32").Call(attr_str); + return ctx->Var("pir") + .Attr(pir::FloatAttribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + 
pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_str = ctx->String(std::to_string(attr_impl.data())); + const auto& attr_data_val = + ctx->Var("DataValue").Attr("float64").Call(attr_str); + return ctx->Var("pir") + .Attr(pir::DoubleAttribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_str = ctx->String(std::to_string(attr_impl.data())); + const auto& attr_data_val = + ctx->Var("DataValue").Attr("int32").Call(attr_str); + return ctx->Var("pir") + .Attr(pir::Int32Attribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_str = ctx->String(std::to_string(attr_impl.data())); + const auto& attr_data_val = + ctx->Var("DataValue").Attr("index").Call(attr_str); + return ctx->Var("pir") + .Attr(pir::IndexAttribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_str = ctx->String(std::to_string(attr_impl.data())); + const auto& attr_data_val = + ctx->Var("DataValue").Attr("int64").Call(attr_str); + return ctx->Var("pir") + .Attr(pir::Int64Attribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + const auto& attr_str = ctx->String([&] { + 
std::ostringstream ss; + ss << attr_impl.data(); + return ss.str(); + }()); + const auto& attr_data_val = + ctx->Var("PointerValue").Attr("void_ptr").Call(attr_str); + return ctx->Var("pir") + .Attr(pir::PointerAttribute::name()) + .Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + ADT_LET_CONST_REF( + attr_data_val, + PirToAnfExprHelper{}.ConvertPirTypeToAnfExpr(ctx, attr_impl.data())); + return ctx->Var("pir").Attr(pir::TypeAttribute::name()).Call(attr_data_val); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + return ctx->Var("pir") + .Attr(pir::StrAttribute::name()) + .Call(attr_impl.AsString()); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + std::vector elts_anf_exprs{}; + const auto& data = attr_impl.AsVector(); + elts_anf_exprs.reserve(data.size()); + for (const auto& elt : data) { + ADT_LET_CONST_REF(elt_anf_expr, + PirToAnfExprHelper{}.ConvertPirAttrToAnfExpr(ctx, elt)); + elts_anf_exprs.emplace_back(elt_anf_expr); + } + return ctx->Var("pir") + .Attr(pir::ArrayAttribute::name()) + .Call(ctx->Var(axpr::kBuiltinList()).Apply(elts_anf_exprs)); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa()); + const auto& attr_impl = attr.template dyn_cast(); + return ctx->Var("pir") + .Attr(pir::TensorNameAttribute::name()) + .Call(attr_impl.data()); + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result 
Call(axpr::LetContext* ctx, + pir::Attribute attr) { + return adt::errors::NotImplementedError{ + std::string() + "pir." + pir::shape::SymbolAttribute::name() + + "() not implemented"}; + } +}; + +template <> +struct AttrToAnfExprConverter<::paddle::dialect::KernelAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + return adt::errors::NotImplementedError{ + std::string() + "pir." + ::paddle::dialect::KernelAttribute::name() + + "() not implemented"}; + } +}; + +template <> +struct AttrToAnfExprConverter<::paddle::dialect::IntArrayAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa<::paddle::dialect::IntArrayAttribute>()); + const auto& attr_impl = + attr.template dyn_cast<::paddle::dialect::IntArrayAttribute>(); + std::vector elts_anf_exprs{}; + const auto& phi_int_array = attr_impl.data(); + const auto& data = phi_int_array.GetData(); + elts_anf_exprs.reserve(data.size()); + for (const auto& elt : data) { + const auto& elt_anf_expr = ctx->Int64(elt); + elts_anf_exprs.emplace_back(elt_anf_expr); + } + return ctx->Var("pir") + .Attr(::paddle::dialect::IntArrayAttribute::name()) + .Call(ctx->Var(axpr::kBuiltinList()).Apply(elts_anf_exprs)); + } +}; + +adt::Result ConvertToDataValueAnfExpr( + axpr::LetContext* ctx, const phi::Scalar& scalar) { + ADT_LET_CONST_REF(data_value, ScalarHelper{}.ConvertToDataValue(scalar)); + using RetT = adt::Result; + return data_value.Match( + [&](const axpr::complex64& impl) -> RetT { + return ConvertToComplex64AnfExpr(ctx, impl); + }, + [&](const axpr::complex128& impl) -> RetT { + return ConvertToComplex128AnfExpr(ctx, impl); + }, + [&](const auto& impl) -> RetT { + try { + const auto& data_type = data_value.GetType(); + const auto& data_val_str = ctx->String(scalar.ToRawString()); + return ctx->Var("DataValue") + .Attr(data_type.Name()) + .Call(data_val_str); + } catch (const std::exception& e) { + return 
adt::errors::RuntimeError{e.what()}; + } + }); +} + +template <> +struct AttrToAnfExprConverter<::paddle::dialect::ScalarAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa<::paddle::dialect::ScalarAttribute>()); + const auto& attr_impl = + attr.template dyn_cast<::paddle::dialect::ScalarAttribute>(); + ADT_LET_CONST_REF(data_value_anf_expr, + ConvertToDataValueAnfExpr(ctx, attr_impl.data())); + return ctx->Var("pir") + .Attr(::paddle::dialect::ScalarAttribute::name()) + .Call(data_value_anf_expr); + } +}; + +template <> +struct AttrToAnfExprConverter<::paddle::dialect::DataTypeAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa<::paddle::dialect::DataTypeAttribute>()); + const auto& attr_impl = + attr.template dyn_cast<::paddle::dialect::DataTypeAttribute>(); + ADT_LET_CONST_REF(data_type, + axpr::GetDataTypeFromPhiDataType(attr_impl.data())); + const auto& data_type_anf_expr = + ctx->Var("DataType").Attr(data_type.Name()); + return ctx->Var("pir") + .Attr(::paddle::dialect::DataTypeAttribute::name()) + .Call(data_type_anf_expr); + } +}; + +adt::Result ConvertToPlaceAnfExpr(axpr::LetContext* ctx, + const phi::Place& place) { + if (place.GetType() == phi::AllocationType::UNDEFINED) { + return ctx->Var("pir").Attr("UndefinedPlace").Call(); + } else if (place.GetType() == phi::AllocationType::CPU) { + return ctx->Var("pir").Attr("CPUPlace").Call(); + } else if (place.GetType() == phi::AllocationType::GPU) { + const auto& device_id = ctx->Int64(place.GetDeviceId()); + return ctx->Var("pir").Attr("GPUPlace").Call(device_id); + } else if (place.GetType() == phi::AllocationType::GPUPINNED) { + return ctx->Var("pir").Attr("GPUPinnedPlace").Call(); + } else if (place.GetType() == phi::AllocationType::XPU) { + const auto& device_id = ctx->Int64(place.GetDeviceId()); + return ctx->Var("pir").Attr("XPUPlace").Call(device_id); + } else if 
(place.GetType() == phi::AllocationType::IPU) { + const auto& device_id = ctx->Int64(place.GetDeviceId()); + return ctx->Var("pir").Attr("IPUPlace").Call(device_id); + } else if (place.GetType() == phi::AllocationType::CUSTOM) { + const auto& device_type = ctx->String(place.GetDeviceType()); + const auto& device_id = ctx->Int64(place.GetDeviceId()); + return ctx->Var("pir").Attr("CustomPlace").Call(device_type, device_id); + } + return adt::errors::TypeError{ + "ConvertToPlaceAnfExpr() failed. invalid place"}; +} + +template <> +struct AttrToAnfExprConverter<::paddle::dialect::PlaceAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa<::paddle::dialect::PlaceAttribute>()); + const auto& attr_impl = + attr.template dyn_cast<::paddle::dialect::PlaceAttribute>(); + ADT_LET_CONST_REF(place_anf_expr, + ConvertToPlaceAnfExpr(ctx, attr_impl.data())); + return ctx->Var("pir") + .Attr(::paddle::dialect::PlaceAttribute::name()) + .Call(place_anf_expr); + } +}; + +template <> +struct AttrToAnfExprConverter<::paddle::dialect::DataLayoutAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + ADT_CHECK(attr.template isa<::paddle::dialect::DataLayoutAttribute>()); + const auto& attr_impl = + attr.template dyn_cast<::paddle::dialect::DataLayoutAttribute>(); + ADT_LET_CONST_REF(data_layout_anf_expr, + ConvertToDataLayoutAnfExpr(ctx, attr_impl.data())); + return ctx->Var("pir") + .Attr(::paddle::dialect::DataLayoutAttribute::name()) + .Call(data_layout_anf_expr); + } +}; + +template <> +struct AttrToAnfExprConverter<::cinn::dialect::GroupInfoAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + return adt::errors::NotImplementedError{ + std::string() + "pir." 
+ ::cinn::dialect::GroupInfoAttribute::name() + + "() is not implemneted"}; + } +}; + +template <> +struct AttrToAnfExprConverter<::cinn::dialect::CINNKernelInfoAttribute> { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + return adt::errors::NotImplementedError{ + std::string() + "pir." + + ::cinn::dialect::CINNKernelInfoAttribute::name() + + "() is not implemneted"}; + } +}; + +template <> +struct AttrToAnfExprConverter { + static adt::Result Call(axpr::LetContext* ctx, + pir::Attribute attr) { + return adt::errors::NotImplementedError{std::string() + "pir." + + UnclassifiedAttribute::name() + + "() is not implemneted"}; + } +}; + +} // namespace + +adt::Result PirToAnfExprHelper::ConvertTypeToAnfExpr( + axpr::LetContext* ctx, axpr::Value type) { + ADT_LET_CONST_REF(pir_type, type.template CastTo()); + return ConvertPirTypeToAnfExpr(ctx, pir_type); +} + +adt::Result PirToAnfExprHelper::ConvertPirTypeToAnfExpr( + axpr::LetContext* ctx, pir::Type type) { + const auto& type_id = GetTypeAdtTypeId(type); + using RetT = adt::Result; + return type_id.Match([&](const auto& impl) -> RetT { + using T = typename std::decay_t::type; + return TypeToAnfExprConverter::Call(ctx, type); + }); +} + +adt::Result PirToAnfExprHelper::ConvertAttrToAnfExpr( + axpr::LetContext* ctx, axpr::Value attr) { + ADT_LET_CONST_REF(pir_attr, attr.template CastTo()); + return ConvertPirAttrToAnfExpr(ctx, pir_attr); +} + +adt::Result PirToAnfExprHelper::ConvertPirAttrToAnfExpr( + axpr::LetContext* ctx, pir::Attribute attr) { + const auto& attr_id = GetAttrAdtTypeId(attr); + using RetT = adt::Result; + return attr_id.Match([&](const auto& impl) -> RetT { + using T = typename std::decay_t::type; + return AttrToAnfExprConverter::Call(ctx, attr); + }); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/program_method_class.cc b/paddle/ap/src/paddle/pir/program_method_class.cc new file mode 100644 index 00000000000000..4c5276bbb5a5d9 --- /dev/null +++ 
b/paddle/ap/src/paddle/pir/program_method_class.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/paddle/pir/program_method_class.h" +#include "paddle/ap/include/axpr/dim_expr_method_class.h" +#include "paddle/ap/include/paddle/pir/attribute_method_class.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" +#include "paddle/ap/include/paddle/pir_node.h" + +namespace ap::paddle { + +struct PirProgramMethodClass { + using This = PirProgramMethodClass; + using Self = Program; + + static adt::Result ToString( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + std::ostringstream ss; + ADT_LET_CONST_REF(self, self_val.template CastTo()); + pir::IrPrinter(ss).PrintProgram(self->pir_program.get()); + return ss.str(); + } + + static adt::Result Empty(const axpr::Value& self_val, + const std::vector& args) { + ADT_CHECK(args.size() == 0); + std::ostringstream ss; + ADT_LET_CONST_REF(self, self_val.template CastTo()); + return self->pir_program->block()->size() == 0; + } + + static adt::Result CopyToConstProgramData( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + std::ostringstream ss; + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::vector values{}; + std::unordered_map value2index{}; + std::vector ops{}; + std::unordered_map 
op2index{}; + for (const auto& op : *self->pir_program->block()) { + for (int i = 0; i < op.num_operands(); ++i) { + if (value2index.emplace(op.operand_source(i), values.size()).second) { + values.push_back(op.operand_source(i)); + } + } + op2index[&op] = ops.size(); + ops.push_back(&op); + for (int i = 0; i < op.num_results(); ++i) { + if (value2index.emplace(op.result(i), values.size()).second) { + values.push_back(op.result(i)); + } + } + } + axpr::AttrMap attr_map; + ADT_LET_CONST_REF(value_data, This{}.ConvertToValues(values, op2index)); + attr_map->Set("values", value_data); + ADT_LET_CONST_REF(op_data, This{}.ConvertToOps(ops, value2index)); + attr_map->Set("ops", op_data); + return attr_map; + } + + static adt::Result Clone(const axpr::Value& self_val, + const std::vector& args) { + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(self, self_val.template CastTo()); + pir::IrMapping ir_mapping; + auto new_program = self->pir_program->Clone(ir_mapping); + ADT_RETURN_IF_ERR(This{}.CloneSymbolicShapes( + new_program.get(), self->pir_program.get(), ir_mapping)); + Program ap_program{new_program}; + return GetPirProgramClass().New(ap_program); + } + + adt::Result CloneSymbolicShapes(pir::Program* new_program, + pir::Program* old_program, + const pir::IrMapping& ir_mapping) { + auto* new_shape_analysis = + &::pir::ShapeAnalysisManager::Instance().Get(new_program); + auto* old_shape_analysis = + &::pir::ShapeAnalysisManager::Instance().Get(old_program); + for (const auto& [old_value, new_value] : ir_mapping.GetMap()) { + new_shape_analysis->SetShapeOrDataForValue( + new_value, old_shape_analysis->GetShapeOrDataForValue(old_value)); + } + return adt::Ok{}; + } + + adt::Result> ConvertToOps( + const std::vector& ops, + const std::unordered_map& value2index) { + adt::List ret; + ret->reserve(ops.size()); + int64_t op_index = 0; + for (const auto* op : ops) { + ADT_LET_CONST_REF(op_data, ConvertToOpData(op_index++, op, value2index)); + ret->emplace_back(op_data); + } + 
return ret; + } + + adt::Result ConvertToOpData( + int64_t op_index, + const pir::Operation* op, + const std::unordered_map& value2index) { + axpr::AttrMap attr_map; + attr_map->Set("op_index", op_index); + attr_map->Set("op_name", op->name()); + { + adt::List input_indexes; + input_indexes->reserve(op->num_operands()); + for (int i = 0; i < op->num_operands(); ++i) { + const auto& index_iter = value2index.find(op->operand_source(i)); + ADT_CHECK(index_iter != value2index.end()); + input_indexes->push_back(index_iter->second); + } + attr_map->Set("input_value_indexes", input_indexes); + } + { + adt::List output_indexes; + output_indexes->reserve(op->num_results()); + for (int i = 0; i < op->num_results(); ++i) { + const auto& index_iter = value2index.find(op->result(i)); + ADT_CHECK(index_iter != value2index.end()); + output_indexes->push_back(index_iter->second); + } + attr_map->Set("output_value_indexes", output_indexes); + } + { + axpr::AttrMap op_attributes; + for (const auto& [attr_name, attr_val] : op->attributes()) { + op_attributes->Set(attr_name, GetPirAttributeClass().New(attr_val)); + } + attr_map->Set("attributes", op_attributes); + } + return attr_map; + } + + adt::Result> ConvertToValues( + const std::vector& values, + const std::unordered_map& op2index) { + adt::List ret; + ret->reserve(values.size()); + int64_t value_index = 0; + for (pir::Value value : values) { + ADT_LET_CONST_REF(value_data, + ConvertToValueData(value_index++, value, op2index)); + ret->emplace_back(value_data); + } + return ret; + } + + adt::Result ConvertToValueData( + int64_t value_index, + const pir::Value& value, + const std::unordered_map& op2index) { + if (!value) return adt::Nothing{}; + axpr::AttrMap attr_map; + attr_map->Set("value_index", value_index); + ADT_CHECK(value.defining_op() != nullptr); + const auto& index_iter = op2index.find(value.defining_op()); + ADT_CHECK(index_iter != op2index.end()); + attr_map->Set("defining_op_index", index_iter->second); + 
attr_map->Set("type", GetPirTypeClass().New(value.type())); + ADT_LET_CONST_REF(symbolic_shape, GetShape(value)); + attr_map->Set("symbolic_shape", symbolic_shape); + return attr_map; + } + + adt::Result GetShape(pir::Value value) { + NativeIrValue ir_value{value}; + ADT_LET_CONST_REF(shape_ptr, ir_value.GetShapeDimExprsPtr()); + adt::List lst; + lst->reserve(shape_ptr->size()); + for (const auto& dim_expr : *shape_ptr) { + axpr::BuiltinClassInstance instance{ + axpr::GetDimExprClass(), dim_expr}; + lst->emplace_back(instance); + } + return lst; + } +}; + +axpr::TypeImpl> GetPirProgramClass() { + using Impl = PirProgramMethodClass; + static auto cls( + axpr::MakeBuiltinClass("PirProgram", [&](const auto& Yield) { + Yield("__str__", &Impl::ToString); + Yield("empty", &Impl::Empty); + Yield("copy_to_const_program_data", &Impl::CopyToConstProgramData); + Yield("clone", &Impl::Clone); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc b/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc new file mode 100644 index 00000000000000..10b75ba7e8396a --- /dev/null +++ b/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/paddle/pir/type_adt_type_id.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" + +namespace ap::paddle { + +adt::Result PirShapeOrDataString( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF(self, + self_val.template CastTo()); + std::ostringstream ss; + ss << self; + return ss.str(); +} + +axpr::TypeImpl> +GetPirShapeOrDataClass() { + static auto cls(axpr::MakeBuiltinClass( + "PirShapeOrData", + [&](const auto& Yield) { Yield("__str__", &PirShapeOrDataString); })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/type_method_class.cc b/paddle/ap/src/paddle/pir/type_method_class.cc new file mode 100644 index 00000000000000..166a2df9ddb77b --- /dev/null +++ b/paddle/ap/src/paddle/pir/type_method_class.cc @@ -0,0 +1,506 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/ap/include/paddle/pir/type_method_class.h" +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/paddle/pir/type_adt_type_id.h" + +namespace ap::paddle { + +adt::Result PirTypeString(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + std::ostringstream ss; + ss << self; + return ss.str(); +} + +struct PirTypeGetType { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const auto& type_id = GetTypeAdtTypeId(self); + return type_id.Match([&](const auto& impl) -> std::string { + using T = typename std::decay_t::type; + return T::name(); + }); + } +}; + +struct ConvertToDtype { + static adt::Result Call(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + try { + auto phi_type = ::paddle::dialect::TransToPhiDataType(self); + ADT_LET_CONST_REF(dtype, axpr::GetDataTypeFromPhiDataType(phi_type)); + return dtype; + } catch (const std::exception& e) { + return adt::errors::ValueError{e.what()}; + } + } +}; + +struct PirTypeMatch { + static adt::Result Call( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& packed_args_val) { + ADT_LET_CONST_REF(self, self_val.template CastTo()); + const auto& type_id = GetTypeAdtTypeId(self); + const auto& type_name = type_id.Match([&](const auto& impl) -> std::string { + using T = typename std::decay_t::type; + return T::name(); + }); + const auto& packed_args = + axpr::CastToPackedArgs(packed_args_val); + const auto& [args, kwargs] = *packed_args; + ADT_CHECK(args->size() == 0) << adt::errors::TypeError{ + std::string() + + "PirType.match() supports keyword arguments only, but " + + std::to_string(args->size()) + " positional arguments were given"}; + auto PatternMatch = + 
[&](const auto& impl) -> adt::Result> { + using T = typename std::decay_t::type; + return MakePirTypeImpl::GetCallArgs(self_val); + }; + std::string key = type_name; + if (!kwargs->Has(type_name)) { + if (!kwargs->Has("_")) { + return adt::errors::TypeError{std::string() + + "PirType.match() failed. no keyword '" + + type_name + "' or '_' provided"}; + } + key = "_"; + } + ADT_LET_CONST_REF(func, kwargs->Get(key)); + ADT_LET_CONST_REF(type_make_args, type_id.Match(PatternMatch)); + ADT_CHECK(axpr::CallableHelper{}.IsCallable(func)) + << adt::errors::TypeError{ + std::string() + + "the arguments of PirType.match() should be callable"}; + if (key == "_") { + return interpreter->InterpretCall(func, {}); + } else { + return interpreter->InterpretCall(func, type_make_args.vector()); + } + } +}; + +axpr::TypeImpl> GetPirTypeClass() { + static auto cls( + axpr::MakeBuiltinClass("PirType", [&](const auto& Yield) { + Yield("__str__", &PirTypeString); + Yield("get_type_name", &PirTypeGetType::Call); + Yield("convert_to_dtype", &ConvertToDtype::Call); + Yield("match", &PirTypeMatch::Call); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +adt::Result MakePirTypeImplNullType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + pir::Type type; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplNullType::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplVectorType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ + std::string() + "pir.t_vec() takes 1 argument, but " + + std::to_string(args.size()) + " were given"}; + ADT_LET_CONST_REF(lst, args.at(0).template CastTo>()) + << adt::errors::TypeError{ + std::string() + + "the argument 1 of pir.t_vec() should be a list (not a " + + axpr::GetTypeName(args.at(0)) + ")"}; + std::vector types; + for (const auto& arg : *lst) { + 
ADT_LET_CONST_REF(elt, arg.template CastTo()); + types.emplace_back(elt); + } + const pir::Type type{pir::VectorType::get(pir::IrContext::Instance(), types)}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplVectorType::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(pir_type, self_val.template CastTo()); + ADT_CHECK(pir_type.isa()); + const auto& type_list = pir_type.dyn_cast(); + adt::List ret_list{}; + ret_list->reserve(type_list.size()); + for (int i = 0; i < type_list.size(); ++i) { + ret_list->emplace_back(GetPirTypeClass().New(type_list[i])); + } + return adt::List{axpr::Value{ret_list}}; +} + +adt::Result MakePirTypeImplDenseTensorType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 3); + ADT_LET_CONST_REF(type, args.at(0).template CastTo()); + ADT_LET_CONST_REF(int_list, + args.at(1).template CastTo>()); + std::vector dims; + dims.reserve(int_list->size()); + for (const auto& int_val : *int_list) { + ADT_LET_CONST_REF(elt, int_val.template CastTo()); + dims.emplace_back(elt); + } + ::common::DDim ddim(dims.data(), dims.size()); + ADT_LET_CONST_REF(data_layout_str, args.at(2).template CastTo()); + std::optional<::common::DataLayout> data_layout; + try { + data_layout = ::common::StringToDataLayout(data_layout_str); + } catch (const std::exception&) { + return adt::errors::ValueError{"StringToDataLayout('" + data_layout_str + + "') failed"}; + } + ADT_CHECK(data_layout.has_value()); + const pir::Type dense_tensor_type{pir::DenseTensorType::get( + pir::IrContext::Instance(), type, ddim, data_layout.value())}; + return GetPirTypeClass().New(dense_tensor_type); +} + +adt::Result> MakePirTypeImplDenseTensorType::GetCallArgs( + const axpr::Value& self_val) { + ADT_LET_CONST_REF(pir_type, self_val.template CastTo()); + ADT_CHECK(pir_type.isa()); + const auto& dense_tensor_type = pir_type.dyn_cast(); + // dtype + const auto& dtype = GetPirTypeClass().New(dense_tensor_type.dtype()); 
+ // shape + adt::List dims{}; + dims->reserve(dense_tensor_type.dims().size()); + for (int i = 0; i < dense_tensor_type.dims().size(); ++i) { + int64_t dim = dense_tensor_type.dims().at(i); + dims->emplace_back(dim); + } + // data layout + std::string data_layout_str; + try { + data_layout_str = + ::common::DataLayoutToString(dense_tensor_type.data_layout()); + } catch (const std::exception& e) { + return adt::errors::ValueError{e.what()}; + } + return adt::List{dtype, dims, data_layout_str}; +} + +adt::Result MakePirTypeImplBFloat16Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::BFloat16Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplBFloat16Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplFloat16Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Float16Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplFloat16Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplFloat32Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Float32Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplFloat32Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplFloat64Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Float64Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplFloat64Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result 
MakePirTypeImplInt8Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Int8Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplInt8Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplUInt8Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::UInt8Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplUInt8Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplInt16Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Int16Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplInt16Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplInt32Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Int32Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplInt32Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplInt64Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Int64Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplInt64Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplIndexType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type 
type{pir::IndexType::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplIndexType::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplBoolType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::BoolType::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplBoolType::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplComplex64Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Complex64Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplComplex64Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplComplex128Type::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + const pir::Type type{pir::Complex128Type::get(pir::IrContext::Instance())}; + return GetPirTypeClass().New(type); +} + +adt::Result> MakePirTypeImplComplex128Type::GetCallArgs( + const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplSelectedRowsType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 3); + ADT_LET_CONST_REF(type, args.at(0).template CastTo()); + ADT_LET_CONST_REF(int_list, + args.at(1).template CastTo>()); + std::vector dims; + dims.reserve(int_list->size()); + for (const auto& int_val : *int_list) { + ADT_LET_CONST_REF(elt, int_val.template CastTo()); + dims.emplace_back(elt); + } + ::common::DDim ddim(dims.data(), dims.size()); + ADT_LET_CONST_REF(data_layout_str, args.at(2).template CastTo()); + std::optional<::common::DataLayout> data_layout; + try { + data_layout = 
::common::StringToDataLayout(data_layout_str); + } catch (const std::exception&) { + return adt::errors::ValueError{"StringToDataLayout('" + data_layout_str + + "') failed"}; + } + ADT_CHECK(data_layout.has_value()); + const pir::Type pir_type{::paddle::dialect::SelectedRowsType::get( + pir::IrContext::Instance(), type, ddim, data_layout.value())}; + return GetPirTypeClass().New(pir_type); +} + +adt::Result> +MakePirTypeImplSelectedRowsType::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF( + pir_type, + self_val.template CastTo<::paddle::dialect::SelectedRowsType>()); + // dtype + const auto& dtype = GetPirTypeClass().New(pir_type.dtype()); + // shape + adt::List dims{}; + dims->reserve(pir_type.dims().size()); + for (int i = 0; i < pir_type.dims().size(); ++i) { + int64_t dim = pir_type.dims().at(i); + dims->emplace_back(dim); + } + // data layout + std::string data_layout_str; + try { + data_layout_str = ::common::DataLayoutToString(pir_type.data_layout()); + } catch (const std::exception& e) { + return adt::errors::ValueError{e.what()}; + } + return adt::List{dtype, dims, data_layout_str}; +} + +adt::Result MakePirTypeImplDenseTensorArrayType::Call( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 3); + ADT_LET_CONST_REF(type, args.at(0).template CastTo()); + ADT_LET_CONST_REF(int_list, + args.at(1).template CastTo>()); + std::vector dims; + dims.reserve(int_list->size()); + for (const auto& int_val : *int_list) { + ADT_LET_CONST_REF(elt, int_val.template CastTo()); + dims.emplace_back(elt); + } + ::common::DDim ddim(dims.data(), dims.size()); + ADT_LET_CONST_REF(data_layout_str, args.at(2).template CastTo()); + std::optional<::common::DataLayout> data_layout; + try { + data_layout = ::common::StringToDataLayout(data_layout_str); + } catch (const std::exception&) { + return adt::errors::ValueError{"StringToDataLayout('" + data_layout_str + + "') failed"}; + } + ADT_CHECK(data_layout.has_value()); + const 
pir::Type dense_tensor_type{ + ::paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), type, ddim, data_layout.value())}; + return GetPirTypeClass().New(dense_tensor_type); +} + +adt::Result> +MakePirTypeImplDenseTensorArrayType::GetCallArgs(const axpr::Value& self_val) { + ADT_LET_CONST_REF( + dense_tensor_array_type, + self_val.template CastTo<::paddle::dialect::DenseTensorArrayType>()); + // dtype + const auto& dtype = GetPirTypeClass().New(dense_tensor_array_type.dtype()); + // shape + adt::List dims{}; + dims->reserve(dense_tensor_array_type.dims().size()); + for (int i = 0; i < dense_tensor_array_type.dims().size(); ++i) { + int64_t dim = dense_tensor_array_type.dims().at(i); + dims->emplace_back(dim); + } + // data layout + std::string data_layout_str; + try { + data_layout_str = + ::common::DataLayoutToString(dense_tensor_array_type.data_layout()); + } catch (const std::exception& e) { + return adt::errors::ValueError{e.what()}; + } + return adt::List{dtype, dims, data_layout_str}; +} + +adt::Result MakePirTypeImplSparseCooTensorType::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{ + std::string() + ::paddle::dialect::SparseCooTensorType::name() + + "() is not implemented"}; +} + +adt::Result> +MakePirTypeImplSparseCooTensorType::GetCallArgs(const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplSparseCsrTensorType::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{ + std::string() + ::paddle::dialect::SparseCsrTensorType::name() + + "() is not implemented"}; +} + +adt::Result> +MakePirTypeImplSparseCsrTensorType::GetCallArgs(const axpr::Value& self_val) { + return adt::List{}; +} + +adt::Result MakePirTypeImplUnclassifiedType::Call( + const axpr::Value& self_val, const std::vector& args) { + return adt::errors::NotImplementedError{ + std::string() + UnclassifiedType::name() + "() is not 
implemented"}; +} + +adt::Result> +MakePirTypeImplUnclassifiedType::GetCallArgs(const axpr::Value& self_val) { + return adt::List{}; +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/reified_drr/reified_drr_pass_dump_helper.cc b/paddle/ap/src/reified_drr/reified_drr_pass_dump_helper.cc new file mode 100644 index 00000000000000..37a74c781e2371 --- /dev/null +++ b/paddle/ap/src/reified_drr/reified_drr_pass_dump_helper.cc @@ -0,0 +1,287 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/reified_drr/reified_drr_pass_dump_helper.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/code_module/module_compile_helper.h" +#include "paddle/ap/include/fs/fs.h" +#include "paddle/ap/include/reified_drr/reified_res_ptn_axpr_maker.h" +#include "paddle/ap/include/reified_drr/reified_src_ptn_axpr_maker.h" + +namespace ap::reified_drr { + +struct ReifiedDrrPassDumpHelperImpl { + drr::DrrCtx abstract_drr_ctx_; + DrrNodeAttrToAnfExprHelper* attr2axpr_helper_; + MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper_; + std::function>( + const std::string&)> + CodeGenResult4FusedOpName_; + int64_t nice_; + + struct DumpCtx { + std::optional dump_dir; + std::optional reified_drr_pass_class_lambda_anf_expr; + }; + + // Returns reified drr_pass_class lambda + adt::Result Dump() { + DumpCtx dump_ctx; + ADT_LET_CONST_REF(anf_expr, ConvertToModuleAnfExpr(&dump_ctx)); + if (dump_ctx.dump_dir.has_value()) { + const auto& reified_drr_json = anf_expr.DumpToJsonString(); + const auto& reified_drr_json_path = + dump_ctx.dump_dir.value() + "/reified_drr.json"; + ADT_RETURN_IF_ERR( + fs::WriteFileContent(reified_drr_json_path, reified_drr_json)); + } + ADT_CHECK(dump_ctx.reified_drr_pass_class_lambda_anf_expr.has_value()); + return dump_ctx.reified_drr_pass_class_lambda_anf_expr.value(); + } + + adt::Result ConvertToModuleAnfExpr(DumpCtx* dump_ctx) { + axpr::LambdaExprBuilder lmd; + auto GetBody = [&](auto& ctx) -> adt::Result { + ADT_RETURN_IF_ERR(DefineAxprModule(&ctx, dump_ctx)); + return ctx.None(); + }; + return lmd.TryLet(GetBody); + } + + adt::Result DefineAxprModule(axpr::LetContext* ctx, + DumpCtx* dump_ctx) { + ADT_LET_CONST_REF(make_drr_ctx_anf_expr, DefineMakeDrrCtxLambda(dump_ctx)); + ADT_LET_CONST_REF(drr_pass_class_anf_expr, + DefineDrrPassClass(ctx, make_drr_ctx_anf_expr)); + auto GetBody = [&](auto& let_ctx) -> adt::Result { + return 
DefineDrrPassClass(&let_ctx, make_drr_ctx_anf_expr); + }; + ADT_LET_CONST_REF(lambda, axpr::LambdaExprBuilder{}.TryLambda({}, GetBody)); + dump_ctx->reified_drr_pass_class_lambda_anf_expr = lambda; + ADT_RETURN_IF_ERR( + InsertRegisterReifiedDrrPass(ctx, drr_pass_class_anf_expr)); + return adt::Ok{}; + } + + adt::Result DefineMakeDrrCtxLambda(DumpCtx* dump_ctx) { + auto GetBody = [&](auto& ctx) -> adt::Result { + auto& drr_ctx = ctx.Var("DrrCtx").Call(); + ADT_LET_CONST_REF(src_ptn_func, DefineSourcePatternFunc()); + ADT_LET_CONST_REF(constraint_lambda, DefineConstraintLambda()); + ADT_LET_CONST_REF(res_ptn_func, + DefineOrGetResultPatternFunc( + src_ptn_func, constraint_lambda, dump_ctx)); + drr_ctx.Attr("set_drr_pass_type") + .Call(ctx.String("reified_drr_pass_type")); + drr_ctx.Attr("init_source_pattern").Call(src_ptn_func); + const auto& constraint_func_name = ctx.NewTmpVarName(); + ctx.Var(constraint_func_name) = constraint_lambda; + const auto& constraint_func = + ctx.Var(constraint_func_name).Attr("__function__"); + drr_ctx.Attr("init_constraint_func").Call(constraint_func); + drr_ctx.Attr("init_result_pattern").Call(res_ptn_func); + return drr_ctx; + }; + return axpr::LambdaExprBuilder{}.TryLambda({"self"}, GetBody); + } + + adt::Result DefineSourcePatternFunc() { + ADT_CHECK(abstract_drr_ctx_->source_pattern_ctx.has_value()); + ADT_CHECK(abstract_drr_ctx_->source_pattern_ctx.value() == + matched_src_ptn_ctx_helper_->src_ptn_ctx()); + ReifiedSrcPtnAxprMaker maker{attr2axpr_helper_, + matched_src_ptn_ctx_helper_}; + auto GetBody = [&](auto& ctx) -> adt::Result { + auto* op_pattern_ctx = &ctx.Var("o"); + auto* tensor_pattern_ctx = &ctx.Var("t"); + ADT_RETURN_IF_ERR(maker.GenAnfExprForSrcPtnCtxOps(op_pattern_ctx)); + ADT_RETURN_IF_ERR(maker.GenAnfExprForSrcPtnCtxValues(tensor_pattern_ctx)); + ADT_RETURN_IF_ERR(maker.GenAnfExprForSrcPtnCtxOpValueConnections( + op_pattern_ctx, tensor_pattern_ctx)); + return ctx.None(); + }; + return 
axpr::LambdaExprBuilder{}.TryLambda({"o", "t"}, GetBody); + } + + adt::Result DefineConstraintLambda() { + axpr::LambdaExprBuilder lmbd; + return lmbd.Lambda({"o", "t", "ir_helper"}, + [&](auto& ctx) { return ctx.Bool(true); }); + } + + adt::Result DefineOrGetResultPatternFunc( + const axpr::AnfExpr& src_ptn_func, + const axpr::AnfExpr& constraint_func, + DumpCtx* dump_ctx) { + std::string src_ptn_func_json = src_ptn_func.DumpToJsonString(); + std::string constraint_func_json = constraint_func.DumpToJsonString(); + std::hash str_hash{}; + std::size_t pattern_hash_value = adt::hash_combine( + str_hash(src_ptn_func_json), str_hash(constraint_func_json)); + ADT_CHECK(abstract_drr_ctx_->pass_name.has_value()); + const std::string relative_dump_dir = + DecodeIntoDirectoryName(abstract_drr_ctx_->pass_name.value()) + "_" + + std::to_string(pattern_hash_value); + ADT_LET_CONST_REF(dump_root_dir, GetDumpDir()); + const std::string& src_ptn_func_json_path = + dump_root_dir + "/" + relative_dump_dir + "/source_pattern_func.json"; + const std::string& constraint_func_json_path = + dump_root_dir + "/" + relative_dump_dir + "/constraint_func.json"; + const std::string& res_ptn_func_json_path = + dump_root_dir + "/" + relative_dump_dir + "/result_pattern_func.json"; + if (fs::FileExists(res_ptn_func_json_path)) { + std::string old_src_ptn_func_json; + ADT_RETURN_IF_ERR( + fs::ReadFileContent(src_ptn_func_json_path, &old_src_ptn_func_json)); + std::string old_constraint_func_json; + ADT_RETURN_IF_ERR(fs::ReadFileContent(constraint_func_json_path, + &old_constraint_func_json)); + ADT_CHECK(old_src_ptn_func_json == src_ptn_func_json); + ADT_CHECK(old_constraint_func_json == constraint_func_json); + std::string res_ptn_func_json; + ADT_RETURN_IF_ERR( + fs::ReadFileContent(res_ptn_func_json_path, &res_ptn_func_json)); + ADT_LET_CONST_REF(res_ptn_func, + axpr::MakeAnfExprFromJsonString(res_ptn_func_json)); + return res_ptn_func; + } else { + dump_ctx->dump_dir = dump_root_dir + "/" + 
relative_dump_dir; + code_module::ModuleCompileHelper compile_helper{dump_root_dir, + relative_dump_dir}; + using RetT = adt::Result>; + auto CodeGenResult4FusedOpName = + [&](const std::string& op_unique_name) -> RetT { + ADT_LET_CONST_REF(code_gen_result, + CodeGenResult4FusedOpName_(op_unique_name)); + ADT_LET_CONST_REF(package_module, + compile_helper.CompileProjectModuleToPackageModule( + code_gen_result->code_module)); + return code_gen::CodeGenResult{ + package_module, + code_gen_result->kernel_dispatch_func, + code_gen_result->kernel_dispatch_const_data}; + }; + ADT_LET_CONST_REF(res_ptn_func, + DefineResultPatternFunc(CodeGenResult4FusedOpName)); + const auto& res_ptn_func_json = res_ptn_func.DumpToJsonString(); + ADT_RETURN_IF_ERR( + fs::WriteFileContent(src_ptn_func_json_path, src_ptn_func_json)); + ADT_RETURN_IF_ERR(fs::WriteFileContent(constraint_func_json_path, + constraint_func_json)); + ADT_RETURN_IF_ERR( + fs::WriteFileContent(res_ptn_func_json_path, res_ptn_func_json)); + return res_ptn_func; + } + } + + std::string DecodeIntoDirectoryName(std::string str) { + for (int i = 0; i < str.size(); ++i) { + if (str.at(i) >= 'A' && str.at(i) <= 'Z') { + continue; + } + if (str.at(i) >= 'a' && str.at(i) <= 'z') { + continue; + } + if (str.at(i) >= '0' && str.at(i) <= '9') { + continue; + } + str[i] = '_'; + } + return str; + } + + adt::Result GetDumpDir() { + const char* dump_dir = std::getenv("AP_PACKAGE_DUMP_DIR"); + ADT_CHECK(dump_dir != nullptr); + return std::string(dump_dir); + } + + adt::Result DefineResultPatternFunc( + const std::function>( + const std::string&)>& CodeGenResult4FusedOpName) { + ADT_CHECK(abstract_drr_ctx_->result_pattern_ctx.has_value()); + ReifiedResPtnAxprMaker maker(abstract_drr_ctx_->result_pattern_ctx.value(), + CodeGenResult4FusedOpName); + auto GetBody = [&](auto& ctx) -> adt::Result { + auto* op_pattern_ctx = &ctx.Var("o"); + auto* tensor_pattern_ctx = &ctx.Var("t"); + 
ADT_RETURN_IF_ERR(maker.GenAnfExprForResPtnCtxOps(op_pattern_ctx)); + ADT_RETURN_IF_ERR(maker.GenAnfExprForResPtnCtxOpValueConnections( + op_pattern_ctx, tensor_pattern_ctx)); + return ctx.None(); + }; + return axpr::LambdaExprBuilder{}.TryLambda({"o", "t"}, GetBody); + } + + adt::Result DefineDrrPassClass( + axpr::LetContext* ctx, const axpr::AnfExpr& make_drr_lambda) { + const auto& class_name = ctx->String(std::string("ReifiedDrrPass")); + const auto& superclasses = ctx->Var(axpr::kBuiltinList()).Call(); + const auto& make_drr_func_name = ctx->NewTmpVarName(); + ctx->Var(make_drr_func_name) = make_drr_lambda; + const auto& methods = [&] { + std::vector args{}; + std::map kwargs{ + {"make_drr_ctx", ctx->Var(make_drr_func_name).Attr("__function__")}}; + return ctx->Var("BuiltinSerializableAttrMap").Apply(args, kwargs); + }(); + return ctx->Var("type").Call(class_name, superclasses, methods); + } + + adt::Result InsertRegisterReifiedDrrPass( + axpr::LetContext* ctx, const axpr::AnfExpr& drr_pass_class) { + ADT_LET_CONST_REF(pass_name_val, GetPassName(drr_pass_class)); + const auto& pass_name = ctx->String(pass_name_val); + const auto& nice = ctx->Int64(nice_); + ctx->Var("Registry") + .Attr("classic_drr_pass") + .Call(pass_name, nice, drr_pass_class); + return adt::Ok{}; + } + + adt::Result GetPassName(const axpr::AnfExpr& drr_pass_class) { + ADT_CHECK(abstract_drr_ctx_->pass_name.has_value()); + std::size_t hash_value = GetHashValue(drr_pass_class); + return abstract_drr_ctx_->pass_name.value() + "_reified_" + + std::to_string(hash_value); + } + + std::size_t GetHashValue(const axpr::AnfExpr& anf_expr) { + const auto& serialized = anf_expr.DumpToJsonString(); + return std::hash()(serialized); + } +}; + +bool ReifiedDrrPassDumpHelper::DumpEnabled() { + return std::getenv("AP_PACKAGE_DUMP_DIR") != nullptr; +} + +// Returns reified drr_pass_class lambda +adt::Result ReifiedDrrPassDumpHelper::Dump( + const drr::DrrCtx& abstract_drr_ctx, + DrrNodeAttrToAnfExprHelper* 
attr2axpr_helper, + MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper, + const std::function>( + const std::string&)>& CodeGenResult4FusedOpName, + int64_t nice) const { + ReifiedDrrPassDumpHelperImpl impl{abstract_drr_ctx, + attr2axpr_helper, + matched_src_ptn_ctx_helper, + CodeGenResult4FusedOpName, + nice}; + return impl.Dump(); +} + +} // namespace ap::reified_drr diff --git a/paddle/ap/src/reified_drr/reified_res_ptn_axpr_maker.cc b/paddle/ap/src/reified_drr/reified_res_ptn_axpr_maker.cc new file mode 100644 index 00000000000000..5816dd10b3e3b0 --- /dev/null +++ b/paddle/ap/src/reified_drr/reified_res_ptn_axpr_maker.cc @@ -0,0 +1,275 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/reified_drr/reified_res_ptn_axpr_maker.h" +#include "paddle/ap/include/code_module/module_to_axpr_helper.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" + +namespace ap::reified_drr { + +namespace { + +using ResPtnCtxOpImpl = + std::variant, drr::PackedIrOp>; + +struct ResPtnCtxOp : public ResPtnCtxOpImpl { + using ResPtnCtxOpImpl::ResPtnCtxOpImpl; + ADT_DEFINE_VARIANT_METHODS(ResPtnCtxOpImpl); + + const std::string& op_unique_name() const { + using RetT = const std::string&; + return Match([&](const auto& impl) -> RetT { return impl->name; }); + } + + static std::optional CastFromDrrNode(const drr::Node& drr_node) { + using RetT = std::optional; + return drr_node.Match( + [&](const drr::NativeIrOp& impl) -> RetT { return impl; }, + [&](const drr::PackedIrOp& impl) -> RetT { return impl; }, + [&](const auto&) -> RetT { return std::nullopt; }); + } + + adt::Result>> GetInputValueNames() + const { + std::vector> ret; + ADT_LET_CONST_REF(reserved_size, num_inputs()); + ret.reserve(reserved_size); + auto CollectValueName = + [&](const auto& op_operand) -> adt::Result { + ADT_LET_CONST_REF(upstreams, op_operand.UpstreamNodes()); + if (upstreams.size() == 0) { + ret.emplace_back(std::nullopt); + } else { + ADT_LET_CONST_REF(input_node, upstreams.Sole()); + ADT_LET_CONST_REF(input, input_node.Get()); + const auto& ir_value = drr::IrValue::OptCastFrom(input); + ADT_CHECK(ir_value.has_value()); + ret.emplace_back(ir_value.value().name()); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstream(CollectValueName)); + return ret; + } + + adt::Result> GetOutputValueNames() const { + std::vector ret; + ADT_LET_CONST_REF(reserved_size, num_outputs()); + ret.reserve(reserved_size); + auto CollectValueName = [&](const auto& op_result) -> adt::Result { + ADT_LET_CONST_REF(downstreams, op_result.DownstreamNodes()); + ADT_LET_CONST_REF(output_node, downstreams.Sole()); + 
ADT_LET_CONST_REF(output, output_node.Get()); + const auto& ir_value = drr::IrValue::OptCastFrom(output); + ADT_CHECK(ir_value.has_value()); + ret.emplace_back(ir_value.value().name()); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitDownstream(CollectValueName)); + return ret; + } + + template + adt::Result VisitUpstream(const DoEachT& DoEach) const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(upstreams, op->node.UpstreamNodes()); + ADT_RETURN_IF_ERR(upstreams.VisitNodes(DoEach)); + return adt::Ok{}; + }); + } + + template + adt::Result VisitDownstream(const DoEachT& DoEach) const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(downstreams, op->node.DownstreamNodes()); + ADT_RETURN_IF_ERR(downstreams.VisitNodes(DoEach)); + return adt::Ok{}; + }); + } + + adt::Result num_inputs() const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(upstreams, op->node.UpstreamNodes()); + return upstreams.size(); + }); + } + + adt::Result num_outputs() const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(downstreams, op->node.DownstreamNodes()); + return downstreams.size(); + }); + } +}; + +adt::Result GenAnfExprForOpImpl( + axpr::LetVar* op_pattern_ctx, + const std::function>( + const std::string&)>& CodeGenResult4FusedOpName, + const std::string& op_unique_name, + const drr::NativeIrOp& ir_op) { + const auto& op = + op_pattern_ctx->Attr("ap_native_op").Call(ir_op->op_declare->op_name); + op_pattern_ctx->SetAttr(op_unique_name, op); + return adt::Ok{}; +} + +adt::Result GenCodeGenLambda( + const std::function>( + const std::string&)>& CodeGenResult4FusedOpName, + const drr::PackedIrOp& ir_op) { + ADT_LET_CONST_REF(code_gen_result, CodeGenResult4FusedOpName(ir_op->name)); + axpr::LambdaExprBuilder lmd{}; + auto GetBody = [&](auto& ctx) -> adt::Result { + ADT_LET_CONST_REF(code_module_anf_expr, + code_module::ModuleToAxprHelper{}.ConvertModuleToAnfExpr( + &ctx, 
code_gen_result->code_module)); + const auto& kernel_dispatch_lambda_anf_expr = + axpr::ConvertCoreExprToAnfExpr( + code_gen_result->kernel_dispatch_func->lambda); + const auto& kernel_dispatch_func_name = ctx.NewTmpVarName(); + ctx.Var(kernel_dispatch_func_name) = kernel_dispatch_lambda_anf_expr; + const auto& kernel_dispatch_func_anf_expr = + ctx.Var(kernel_dispatch_func_name).Attr("__function__"); + ADT_LET_CONST_REF(kernel_dispatch_const_data_anf_expr, + axpr::BuiltinSerializableAttrMapToAxprHelper{}.Convert( + &ctx, code_gen_result->kernel_dispatch_const_data)); + std::map kwargs{ + {"module", code_module_anf_expr}, + {"kernel_dispatch_func", kernel_dispatch_func_anf_expr}, + {"kernel_dispatch_const_data", kernel_dispatch_const_data_anf_expr}, + }; + return ctx.Var("CodeGenResult").Apply(std::vector{}, kwargs); + }; + ADT_LET_CONST_REF(ret, lmd.TryLambda({"ctx", "o", "t"}, GetBody)); + return ret; +} + +adt::Result GenAnfExprForOpImpl( + axpr::LetVar* op_pattern_ctx, + const std::function>( + const std::string&)>& CodeGenResult4FusedOpName, + const std::string& op_unique_name, + const drr::PackedIrOp& ir_op) { + ADT_LET_CONST_REF(code_gen_lambda, + GenCodeGenLambda(CodeGenResult4FusedOpName, ir_op)); + auto* ctx = op_pattern_ctx->ctx(); + const std::string& lambda_name = ctx->NewTmpVarName(); + ctx->Var(lambda_name) = code_gen_lambda; + const auto& code_gen_func = ctx->Var(lambda_name); + const auto& op = + op_pattern_ctx->Attr("ap_pattern_fusion_op").Call(code_gen_func); + op_pattern_ctx->SetAttr(op_unique_name, op); + return adt::Ok{}; +} + +template +adt::Result VisitEachResPtnCtxOp( + const drr::ResultPatternCtx& res_ptn_ctx, const DoEachT& DoEach) { + for (const auto& node : res_ptn_ctx->node_arena->nodes()) { + const auto& res_ptn_ctx_op = ResPtnCtxOp::CastFromDrrNode(node); + if (res_ptn_ctx_op.has_value()) { + ADT_RETURN_IF_ERR(DoEach(res_ptn_ctx_op.value())); + } + } + return adt::Ok{}; +} + +} // namespace + +adt::Result 
ReifiedResPtnAxprMaker::GenAnfExprForResPtnCtxOps( + axpr::LetVar* op_pattern_ctx) { + using Ok = adt::Result; + auto GenAnfExprForOp = [&](const ResPtnCtxOp& op) -> Ok { + const auto& op_unique_name = op.op_unique_name(); + return op.Match([&](const auto& impl) -> Ok { + return GenAnfExprForOpImpl( + op_pattern_ctx, CodeGenResult4FusedOpName_, op_unique_name, impl); + }); + }; + ADT_RETURN_IF_ERR(VisitEachResPtnCtxOp(res_ptn_ctx_, GenAnfExprForOp)); + return adt::Ok{}; +} + +namespace { + +template +adt::Result VisitEachResPtnCtxOpValueConnection( + const drr::ResultPatternCtx& res_ptn_ctx, const DoEachT& DoEach) { + auto DoEachOp = [&](const ResPtnCtxOp& op) -> adt::Result { + ADT_LET_CONST_REF(input_value_names, op.GetInputValueNames()); + ADT_LET_CONST_REF(output_value_names, op.GetOutputValueNames()); + ADT_RETURN_IF_ERR( + DoEach(op.op_unique_name(), input_value_names, output_value_names)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachResPtnCtxOp(res_ptn_ctx, DoEachOp)); + return adt::Ok{}; +} + +} // namespace + +adt::Result +ReifiedResPtnAxprMaker::GenAnfExprForResPtnCtxOpValueConnections( + axpr::LetVar* op_pattern_ctx, axpr::LetVar* tensor_pattern_ctx) { + ADT_CHECK(op_pattern_ctx->ctx() == tensor_pattern_ctx->ctx()); + auto* ctx = op_pattern_ctx->ctx(); + using Ok = adt::Result; + auto GetDrrIrValueAxpr = + [&](const auto& tensor_name) -> adt::Result { + ADT_LET_CONST_REF(ir_value, + drr::OpTensorPatternCtxHelper{}.GetIrValueByUid( + res_ptn_ctx_->tensor_pattern_ctx, tensor_name)); + using RetT = adt::Result; + return ir_value.Match( + [&](const drr::NativeIrValue&) -> RetT { + return tensor_pattern_ctx->Attr(tensor_name); + }, + [&](const drr::PackedIrValue&) -> RetT { + auto* ctx = tensor_pattern_ctx->ctx(); + return ctx->Var(axpr::kBuiltinStarred()) + .Call(tensor_pattern_ctx->Attr(tensor_name)); + }); + }; + auto BuildConnection = + [&](const std::string& op_unique_name, + const std::vector>& in_names, + const std::vector& out_names) -> Ok { 
+ std::vector in_anf_exprs; + in_anf_exprs.reserve(in_names.size()); + for (const auto& opt_name : in_names) { + if (!opt_name.has_value()) { + in_anf_exprs.emplace_back(ctx->None()); + } else { + ADT_LET_CONST_REF(anf_expr, GetDrrIrValueAxpr(opt_name.value())); + in_anf_exprs.emplace_back(anf_expr); + } + } + std::vector out_anf_exprs; + for (const auto& name : out_names) { + ADT_LET_CONST_REF(anf_expr, GetDrrIrValueAxpr(name)); + out_anf_exprs.emplace_back(anf_expr); + } + op_pattern_ctx->Attr(op_unique_name) + .Call(ctx->Var(axpr::kBuiltinList()).Apply(in_anf_exprs), + ctx->Var(axpr::kBuiltinList()).Apply(out_anf_exprs)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitEachResPtnCtxOpValueConnection(res_ptn_ctx_, BuildConnection)); + return adt::Ok{}; +} + +} // namespace ap::reified_drr diff --git a/paddle/ap/src/reified_drr/reified_src_ptn_axpr_maker.cc b/paddle/ap/src/reified_drr/reified_src_ptn_axpr_maker.cc new file mode 100644 index 00000000000000..e8f9a17294e041 --- /dev/null +++ b/paddle/ap/src/reified_drr/reified_src_ptn_axpr_maker.cc @@ -0,0 +1,332 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ap/include/reified_drr/reified_src_ptn_axpr_maker.h" +#include "paddle/ap/include/drr/node.h" +#include "paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h" + +namespace ap::reified_drr { + +namespace { + +using SrcPtnCtxOpImpl = std::variant, + drr::PackedIrOp, + drr::OptPackedIrOp>; + +struct SrcPtnCtxOp : public SrcPtnCtxOpImpl { + using SrcPtnCtxOpImpl::SrcPtnCtxOpImpl; + ADT_DEFINE_VARIANT_METHODS(SrcPtnCtxOpImpl); + + const std::string& op_unique_name() const { + using RetT = const std::string&; + return Match([&](const auto& impl) -> RetT { return impl->name; }); + } + + static std::optional CastFromDrrNode(const drr::Node& drr_node) { + using RetT = std::optional; + return drr_node.Match( + [&](const drr::NativeIrOp& impl) -> RetT { return impl; }, + [&](const drr::PackedIrOp& impl) -> RetT { return impl; }, + [&](const drr::OptPackedIrOp& impl) -> RetT { return impl; }, + [&](const auto&) -> RetT { return std::nullopt; }); + } + + adt::Result>> GetInputValueNames() + const { + std::vector> ret; + ADT_LET_CONST_REF(reserved_size, num_inputs()); + ret.reserve(reserved_size); + auto CollectValueName = + [&](const auto& op_operand) -> adt::Result { + ADT_LET_CONST_REF(upstreams, op_operand.UpstreamNodes()); + if (upstreams.size() == 0) { + ret.emplace_back(std::nullopt); + } else { + ADT_LET_CONST_REF(input_node, upstreams.Sole()); + ADT_LET_CONST_REF(input, input_node.Get()); + const auto& ir_value = drr::IrValue::OptCastFrom(input); + ADT_CHECK(ir_value.has_value()); + ret.emplace_back(ir_value.value().name()); + } + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitUpstream(CollectValueName)); + return ret; + } + + adt::Result> GetOutputValueNames() const { + std::vector ret; + ADT_LET_CONST_REF(reserved_size, num_outputs()); + ret.reserve(reserved_size); + auto CollectValueName = [&](const auto& op_result) -> adt::Result { + ADT_LET_CONST_REF(downstreams, op_result.DownstreamNodes()); + ADT_LET_CONST_REF(output_node, 
downstreams.Sole()); + ADT_LET_CONST_REF(output, output_node.Get()); + const auto& ir_value = drr::IrValue::OptCastFrom(output); + ADT_CHECK(ir_value.has_value()); + ret.emplace_back(ir_value.value().name()); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitDownstream(CollectValueName)); + return ret; + } + + template + adt::Result VisitUpstream(const DoEachT& DoEach) const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(upstreams, op->node.UpstreamNodes()); + ADT_RETURN_IF_ERR(upstreams.VisitNodes(DoEach)); + return adt::Ok{}; + }); + } + + template + adt::Result VisitDownstream(const DoEachT& DoEach) const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(downstreams, op->node.DownstreamNodes()); + ADT_RETURN_IF_ERR(downstreams.VisitNodes(DoEach)); + return adt::Ok{}; + }); + } + + adt::Result num_inputs() const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(upstreams, op->node.UpstreamNodes()); + return upstreams.size(); + }); + } + + adt::Result num_outputs() const { + return Match([&](const auto& op) -> adt::Result { + ADT_LET_CONST_REF(downstreams, op->node.DownstreamNodes()); + return downstreams.size(); + }); + } +}; + +adt::Result GenAnfExprForOpImpl( + axpr::LetVar* op_pattern_ctx, + DrrNodeAttrToAnfExprHelper* anf_expr_helper, + MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper, + const std::string& op_unique_name, + const drr::NativeIrOp& ir_op) { + const auto& op = + op_pattern_ctx->Attr("ap_native_op").Call(ir_op->op_declare->op_name); + op_pattern_ctx->SetAttr(op_unique_name, op); + using Ok = adt::Result; + auto GenSetAttr = [&](const std::string& attr_name, + const axpr::Value& attr_val) -> Ok { + ADT_LET_CONST_REF( + attr_anf_expr, + anf_expr_helper->ConvertAttrToAnfExpr(op_pattern_ctx->ctx(), attr_val)); + op_pattern_ctx->Attr(op_unique_name).SetAttr(attr_name, attr_anf_expr); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + 
matched_src_ptn_ctx_helper->VisitNativeIrOpAttr(ir_op, GenSetAttr)); + return adt::Ok{}; +} + +adt::Result GenInnerSrcPtnLambda( + DrrNodeAttrToAnfExprHelper* anf_expr_helper, + MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper, + const drr::PackedIrOp& ir_op) { + ADT_LET_CONST_REF( + inner_matched_src_ptn_ctx_helper, + matched_src_ptn_ctx_helper->MakeInnerMatchedSrcPtnCtxHelper(ir_op)); + axpr::LambdaExprBuilder lmd{}; + ReifiedSrcPtnAxprMaker maker{anf_expr_helper, + inner_matched_src_ptn_ctx_helper.get()}; + auto GetBody = [&](auto& ctx) -> adt::Result { + auto* op_pattern_ctx = &ctx.Var("o"); + auto* tensor_pattern_ctx = &ctx.Var("t"); + ADT_RETURN_IF_ERR(maker.GenAnfExprForSrcPtnCtxOps(op_pattern_ctx)); + ADT_RETURN_IF_ERR(maker.GenAnfExprForSrcPtnCtxValues(tensor_pattern_ctx)); + ADT_RETURN_IF_ERR(maker.GenAnfExprForSrcPtnCtxOpValueConnections( + op_pattern_ctx, tensor_pattern_ctx)); + return ctx.None(); + }; + ADT_LET_CONST_REF(ret, lmd.TryLambda({"o", "t"}, GetBody)); + return ret; +} + +adt::Result GenAnfExprForOpImpl( + axpr::LetVar* op_pattern_ctx, + DrrNodeAttrToAnfExprHelper* anf_expr_helper, + MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper, + const std::string& op_unique_name, + const drr::PackedIrOp& ir_op) { + ADT_LET_CONST_REF( + inner_src_ptn_lambda, + GenInnerSrcPtnLambda(anf_expr_helper, matched_src_ptn_ctx_helper, ir_op)); + auto* ctx = op_pattern_ctx->ctx(); + const std::string& lambda_name = ctx->NewTmpVarName(); + ctx->Var(lambda_name) = inner_src_ptn_lambda; + const auto& inner_src_ptn_func = ctx->Var(lambda_name).Attr("__function__"); + const auto& op = + op_pattern_ctx->Attr("ap_trivial_fusion_op").Call(inner_src_ptn_func); + op_pattern_ctx->SetAttr(op_unique_name, op); + return adt::Ok{}; +} + +adt::Result GenAnfExprForOpImpl( + axpr::LetVar* op_pattern_ctx, + DrrNodeAttrToAnfExprHelper* anf_expr_helper, + MatchedSrcPtnCtxHelper* matched_src_ptn_ctx_helper, + const std::string& op_unique_name, + const drr::OptPackedIrOp& ir_op) { + 
return adt::errors::NotImplementedError{ + "GenAnfExprForOpImpl(OptPackedIrOp) not implemented"}; +} + +} // namespace + +namespace { + +template +adt::Result VisitEachSrcPtnCtxOp( + const drr::SourcePatternCtx& src_ptn_ctx, const DoEachT& DoEach) { + for (const auto& node : src_ptn_ctx->node_arena->nodes()) { + const auto& src_ptn_ctx_op = SrcPtnCtxOp::CastFromDrrNode(node); + if (src_ptn_ctx_op.has_value()) { + ADT_RETURN_IF_ERR(DoEach(src_ptn_ctx_op.value())); + } + } + return adt::Ok{}; +} + +} // namespace + +adt::Result ReifiedSrcPtnAxprMaker::GenAnfExprForSrcPtnCtxOps( + axpr::LetVar* op_pattern_ctx) { + using Ok = adt::Result; + auto GenAnfExprForOp = [&](const SrcPtnCtxOp& op) -> Ok { + return op.Match([&](const auto& impl) -> Ok { + return GenAnfExprForOpImpl(op_pattern_ctx, + anf_expr_helper_, + matched_src_ptn_ctx_helper_, + op.op_unique_name(), + impl); + }); + }; + const auto& src_ptn_ctx = matched_src_ptn_ctx_helper_->src_ptn_ctx(); + ADT_RETURN_IF_ERR(VisitEachSrcPtnCtxOp(src_ptn_ctx, GenAnfExprForOp)); + return adt::Ok{}; +} + +namespace { + +template +adt::Result VisitEachSrcPtnCtxValue( + const drr::SourcePatternCtx& src_ptn_ctx, const DoEachT& DoEach) { + for (const auto& node : src_ptn_ctx->node_arena->nodes()) { + if (node.template Has>()) { + ADT_RETURN_IF_ERR( + DoEach(node.template Get>())); + } + } + return adt::Ok{}; +} + +} // namespace + +adt::Result ReifiedSrcPtnAxprMaker::GenAnfExprForSrcPtnCtxValues( + axpr::LetVar* tensor_pattern_ctx) { + using Ok = adt::Result; + auto GenAnfExprForValue = + [&](const drr::NativeIrValue& drr_value) -> Ok { + ADT_LET_CONST_REF( + type, matched_src_ptn_ctx_helper_->GetNativeIrValueType(drr_value)); + ADT_LET_CONST_REF(type_anf_expr, + anf_expr_helper_->ConvertTypeToAnfExpr( + tensor_pattern_ctx->ctx(), type)); + tensor_pattern_ctx->Attr(drr_value->name).SetAttr("type", type_anf_expr); + return adt::Ok{}; + }; + const auto& src_ptn_ctx = matched_src_ptn_ctx_helper_->src_ptn_ctx(); + 
ADT_RETURN_IF_ERR(VisitEachSrcPtnCtxValue(src_ptn_ctx, GenAnfExprForValue)); + return adt::Ok{}; +} + +namespace { + +template +adt::Result VisitEachSrcPtnCtxOpValueConnection( + const drr::SourcePatternCtx& src_ptn_ctx, const DoEachT& DoEach) { + auto DoEachOp = [&](const SrcPtnCtxOp& op) -> adt::Result { + ADT_LET_CONST_REF(input_value_names, op.GetInputValueNames()); + ADT_LET_CONST_REF(output_value_names, op.GetOutputValueNames()); + ADT_RETURN_IF_ERR( + DoEach(op.op_unique_name(), input_value_names, output_value_names)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(VisitEachSrcPtnCtxOp(src_ptn_ctx, DoEachOp)); + return adt::Ok{}; +} + +} // namespace + +adt::Result +ReifiedSrcPtnAxprMaker::GenAnfExprForSrcPtnCtxOpValueConnections( + axpr::LetVar* op_pattern_ctx, axpr::LetVar* tensor_pattern_ctx) { + ADT_CHECK(op_pattern_ctx->ctx() == tensor_pattern_ctx->ctx()); + auto* ctx = op_pattern_ctx->ctx(); + using Ok = adt::Result; + const auto& src_ptn_ctx = matched_src_ptn_ctx_helper_->src_ptn_ctx(); + auto GetDrrIrValueAxpr = + [&](const auto& tensor_name) -> adt::Result { + ADT_LET_CONST_REF(ir_value, + drr::OpTensorPatternCtxHelper{}.GetIrValueByUid( + src_ptn_ctx->tensor_pattern_ctx, tensor_name)); + using RetT = adt::Result; + return ir_value.Match( + [&](const drr::NativeIrValue&) -> RetT { + return tensor_pattern_ctx->Attr(tensor_name); + }, + [&](const drr::PackedIrValue&) -> RetT { + auto* ctx = tensor_pattern_ctx->ctx(); + return ctx->Var(axpr::kBuiltinStarred()) + .Call(tensor_pattern_ctx->Attr(tensor_name)); + }); + }; + auto BuildConnection = + [&](const std::string& op_unique_name, + const std::vector>& in_names, + const std::vector& out_names) -> Ok { + std::vector in_anf_exprs; + in_anf_exprs.reserve(in_names.size()); + for (const auto& opt_name : in_names) { + if (!opt_name.has_value()) { + in_anf_exprs.emplace_back(ctx->None()); + } else { + ADT_LET_CONST_REF(anf_expr, GetDrrIrValueAxpr(opt_name.value())); + in_anf_exprs.emplace_back(anf_expr); + } 
+ } + std::vector out_anf_exprs; + for (const auto& name : out_names) { + ADT_LET_CONST_REF(anf_expr, GetDrrIrValueAxpr(name)); + out_anf_exprs.emplace_back(anf_expr); + } + op_pattern_ctx->Attr(op_unique_name) + .Call(ctx->Var(axpr::kBuiltinList()).Apply(in_anf_exprs), + ctx->Var(axpr::kBuiltinList()).Apply(out_anf_exprs)); + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR( + VisitEachSrcPtnCtxOpValueConnection(src_ptn_ctx, BuildConnection)); + return adt::Ok{}; +} + +} // namespace ap::reified_drr diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index e4e816f8c2b863..14d8c987ae892d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -9,7 +9,8 @@ set(cinn_transforms_deps cinn_runtime_dialect op_fusion pir_compiler - json) + json + ap_pass) include_directories(cinn_transforms PRIVATE ${PADDLE_SOURCE_DIR}/third_party/nlohmann_json/include/) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index e31aade9dc82f1..6ad748ad06b1b8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -27,6 +27,8 @@ #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" #include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/ap/include/memory/guard.h" +#include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.h" @@ -62,6 +64,7 @@ #include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h" #include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" #include 
"paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.h" +#include "paddle/pir/include/core/ir_printer.h" COMMON_DECLARE_bool(cinn_specify_input_dynamic_dim); COMMON_DECLARE_string(cinn_input_dynamic_dim_spec_file); @@ -70,6 +73,7 @@ COMMON_DECLARE_bool(disable_dyshape_in_train); COMMON_DECLARE_bool(enable_cinn_accuracy_check); COMMON_DECLARE_bool(enable_fuse_parallel_matmul_pass); COMMON_DECLARE_bool(enable_fusion_fallback); +COMMON_DECLARE_bool(enable_ap); COMMON_DECLARE_bool(logging_pir_py_code_dump_symbolic_dims); namespace cinn::dialect::ir { @@ -222,20 +226,46 @@ void ApplyCinnLowerPass( VLOG(0) << "Enable CINN Accuracy Check Pass"; pass_manager->AddPass(cinn::dialect::ir::CreateAccuracyCheckPass()); } - if (FLAGS_enable_fusion_fallback) { - VLOG(0) << "Enable Fusion Fallback Pass"; - pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); - } - if (has_dynamic_shape && !force_static_shape) { - pass_manager->AddPass( - cinn::dialect::ir::CreateLowerCinnDyShapeFusionOpPass()); + if (FLAGS_enable_ap) { + ap::memory::Guard guard{}; + if (auto pass = + CreateApLowerFusionOpClassicDrrPass(guard.circlable_ref_list())) { + pass_manager->AddPass(std::move(pass.value())); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pir::IrPrinter(LOG(ERROR) << "before ApLowerFusionOpClassicDrrPass:\n") + .PrintProgram(program); + pass_manager->Run(program); + pir::IrPrinter(LOG(ERROR) << "after ApLowerFusionOpClassicDrrPass:\n") + .PrintProgram(program); + } + if (auto pass = + CreateApLowerFusionOpAbstractDrrPass(guard.circlable_ref_list())) { + pass_manager->AddPass(std::move(pass.value())); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pir::IrPrinter(LOG(ERROR) << "before ApLowerFusionOpAbstractDrrPass:\n") + .PrintProgram(program); + pass_manager->Run(program); + pir::IrPrinter(LOG(ERROR) << "after ApLowerFusionOpAbstractDrrPass:\n") + .PrintProgram(program); + pass_manager = CreatePassManager(); + 
pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); + pass_manager->Run(program); + } } else { - pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); + if (FLAGS_enable_fusion_fallback) { + VLOG(0) << "Enable Fusion Fallback Pass"; + pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); + } + if (has_dynamic_shape && !force_static_shape) { + pass_manager->AddPass( + cinn::dialect::ir::CreateLowerCinnDyShapeFusionOpPass()); + } else { + pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); + } + pass_manager->AddPass( + cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass()); + pass_manager->Run(program); } - pass_manager->AddPass( - cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass()); - - pass_manager->Run(program); } template diff --git a/paddle/common/adt_type_id.h b/paddle/common/adt_type_id.h index 284e795f58f9b6..09aa72ef29b4bf 100644 --- a/paddle/common/adt_type_id.h +++ b/paddle/common/adt_type_id.h @@ -20,7 +20,9 @@ namespace common { template -struct AdtTypeId {}; +struct AdtTypeId { + using type = T; +}; template struct AdtBaseTypeId : public std::variant...> { diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index d46dc3a22ee362..da9f7e9680f612 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1577,6 +1577,8 @@ PHI_DEFINE_EXPORTED_bool(logging_pir_py_code_dump_symbolic_dims, false, "whether dump symbolic dims into pir py code."); +PHI_DEFINE_EXPORTED_bool(enable_ap, false, "whether enable abstract pass."); + PHI_DEFINE_EXPORTED_bool( pir_interpreter_record_stream_for_gc_cache, false, diff --git a/paddle/phi/CMakeLists.txt b/paddle/phi/CMakeLists.txt index 6a17e55d9bcb94..880338061809ad 100644 --- a/paddle/phi/CMakeLists.txt +++ b/paddle/phi/CMakeLists.txt @@ -50,7 +50,8 @@ set(PHI_DEPS xxhash cblas utf8proc - common) + common + ap_phi) list( APPEND diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 
82deb497bbcaec..e90fc02c5c9fbd 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "glog/logging.h" +#include "paddle/ap/include/paddle/phi/ap_infer_meta_helper.h" #include "paddle/common/layout.h" #include "paddle/phi/backends/device_memory_alignment.h" #include "paddle/phi/common/data_type.h" @@ -468,6 +469,24 @@ void AddNInferMeta(const std::vector& x, out->set_dtype(x[0]->dtype()); } +void ApUnaryInferMeta(const std::vector& xs, + int num_outputs, + const std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs, + MetaConfig config) { + ApInferMetaHelper helper{}; + const auto& ret = helper.InferMeta(infer_meta_lambda, &xs, &outs); + PADDLE_ENFORCE(!ret.HasError(), + "ApUnaryInferMeta failed. \nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg()); +} + // TODO(YuanRisheng) This InferMeta is used in Fluid // and will be deleted in the future. 
void AddNTensorArrayInferMeta(const std::vector& x, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index efa0ff346bfbea..6a6fb6c4226200 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -140,6 +140,15 @@ void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void ApUnaryInferMeta(const std::vector& xs, + int num_outputs, + const std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs, + MetaConfig config = MetaConfig()); + void AddNTensorArrayInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config); diff --git a/paddle/phi/kernels/gpu/ap_unary.cu b/paddle/phi/kernels/gpu/ap_unary.cu new file mode 100644 index 00000000000000..be8c7fd7985360 --- /dev/null +++ b/paddle/phi/kernels/gpu/ap_unary.cu @@ -0,0 +1,96 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "glog/logging.h" +#include "jitify.hpp" // NOLINT +#include "paddle/common/enforce.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/impl/activation_grad_impl.h" +#include "paddle/phi/kernels/impl/activation_impl.h" + +#include "paddle/ap/include/kernel_dispatch/ap_unary_kernel.h" +#include "paddle/ap/include/paddle/phi/device_ctx.h" + +namespace phi { + +template +void ApUnaryKernel(const Context& dev_ctx, + const std::vector& xs, + int num_outputs, + const std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs) { + PADDLE_ENFORCE_GT( + xs.size(), + 0, + phi::errors::InvalidArgument( + "At least 1 input is required. current number of inputs: %d", + xs.size())); + PADDLE_ENFORCE_GT( + outs.size(), + 0, + phi::errors::InvalidArgument( + "num_outputs must be greater than 0. current num_outputs: %d", + outs.size())); + for (auto* out : outs) { + dev_ctx.template Alloc(out); + } + std::shared_ptr impl = + std::make_shared>(&dev_ctx); + ap::kernel_dispatch::DeviceCtx ap_device_ctx{impl}; + const auto& ret = + ap::kernel_dispatch::ApUnaryKernel(ap_device_ctx, + xs, + num_outputs, + code_module_lambda, + infer_meta_lambda, + kernel_dispatch_lambda, + kernel_dispatch_const_data_lambda, + outs); + PADDLE_ENFORCE( + !ret.HasError(), + "ap_kernel failed. \nTraceback (most recent call last):\n%s\n%s: %s. 
", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg()); +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(ap_unary, + GPU, + ALL_LAYOUT, + phi::ApUnaryKernel, + float, + double, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(ap_unary, + GPU, + ALL_LAYOUT, + phi::ApUnaryKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#endif diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 95676b91c53dae..fd9deb0839c18a 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -269,6 +269,14 @@ traits : paddle::dialect::ForwardOnlyTrait interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : ap_unary + args : (Tensor[] xs, int num_outputs, str code_module_lambda, str infer_meta_lambda, str kernel_dispatch_lambda, str kernel_dispatch_const_data_lambda) + output : Tensor[](out){num_outputs} + infer_meta : + func : ApUnaryInferMeta + kernel : + func : ap_unary + - op : apply_per_channel_scale args: (Tensor x, Tensor scales) output: Tensor(out) diff --git a/paddle/pir/include/core/op_operand.h b/paddle/pir/include/core/op_operand.h index 4944c31fdb283f..2619b804386bea 100644 --- a/paddle/pir/include/core/op_operand.h +++ b/paddle/pir/include/core/op_operand.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "paddle/pir/include/core/dll_decl.h" namespace pir { @@ -66,5 +67,17 @@ class IR_API OpOperand { private: detail::OpOperandImpl *impl_{nullptr}; + friend struct std::hash; }; } // namespace pir + +namespace std { + +template <> +struct hash { + std::size_t operator()(const pir::OpOperand &obj) const { + return reinterpret_cast(obj.impl_); + } +}; + +} // namespace std diff --git a/paddle/pir/include/core/op_result.h b/paddle/pir/include/core/op_result.h index 89a7b6664230f1..173001d65ef9bd 100644 --- a/paddle/pir/include/core/op_result.h +++ b/paddle/pir/include/core/op_result.h @@ -40,6 +40,8 @@ class IR_API OpResult : 
public Value { void *property(const std::string &key) const; void set_property(const std::string &key, const Property &value); + static bool classof(Value value); + static OpResult dyn_cast_from(Value value); private: friend Operation; @@ -47,8 +49,6 @@ class IR_API OpResult : public Value { // Access classof and dyn_cast_from. friend Value; friend struct std::hash; - static bool classof(Value value); - static OpResult dyn_cast_from(Value value); }; } // namespace pir diff --git a/python/setup.py.in b/python/setup.py.in index 4817f82a6e3bde..aa9e27957da9f8 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -1108,7 +1108,8 @@ if '${CMAKE_BUILD_TYPE}' == 'Release': if platform.machine() != 'sw_64' and platform.machine() != 'mips64': for command in commands: if os.system(command) != 0: - raise Exception("patch ${FLUID_CORE_NAME}.%s failed, command: %s" % (ext_name, command)) + pass + # raise Exception("patch ${FLUID_CORE_NAME}.%s failed, command: %s" % (ext_name, command)) ext_modules = [Extension('_foo', ['stub.cc'])] if os.name == 'nt': From 96fea1cec205d50a3c01c94f9460c676fffdb69d Mon Sep 17 00:00:00 2001 From: lixinqi Date: Thu, 13 Feb 2025 11:24:29 +0000 Subject: [PATCH 02/43] remove unused index_expr code --- paddle/ap/CMakeLists.txt | 16 +- .../ap/include/code_gen/builtin_frame_util.h | 6 - .../code_gen/code_gen_ctx_method_class.h | 81 -- .../code_gen/matched_result_pattern_helper.h | 1 - paddle/ap/include/code_gen/op_code_gen_ctx.h | 1 - paddle/ap/include/code_gen/value.h | 2 - .../ap/include/code_gen/value_method_class.h | 3 - paddle/ap/include/graph/graph_descriptor.h | 9 +- .../include/index_expr/builtin_frame_util.h | 44 - .../index_expr/dim_expr_cuda_code_generator.h | 163 --- paddle/ap/include/index_expr/index_closure.h | 103 -- paddle/ap/include/index_expr/index_expr.h | 171 --- .../index_expr/index_expr_builtin_functions.h | 441 ------- .../index_expr/index_expr_interpreter.h | 47 - .../index_expr/index_expr_method_class.h | 46 - 
.../ap/include/index_expr/index_expr_util.h | 168 --- .../ap/include/index_expr/index_tuple_expr.h | 196 ---- .../index_tuple_expr_cuda_code_generator.h | 97 -- .../index_tuple_expr_method_class.h | 89 -- .../op_index_tuple_expr_signature.h | 44 - ..._index_tuple_expr_signature_method_class.h | 91 -- paddle/ap/include/index_expr/op_signature.h | 78 -- paddle/ap/include/index_expr/slice.h | 47 - .../include/index_expr/slice_method_class.h | 45 - .../index_expr/valid_index_expr_builder.h | 248 ---- paddle/ap/include/index_expr/value.h | 32 - .../include/index_expr/value_method_class.h | 22 - paddle/ap/include/paddle/indexed_ir_graph.h | 60 - .../ap/include/paddle/indexed_ir_graph_util.h | 224 ---- paddle/ap/include/paddle/indexed_ir_node.h | 105 -- .../ap/include/paddle/op_cuda_code_gen_impl.h | 1018 ----------------- paddle/ap/src/index_expr/index_closure.cc | 100 -- .../index_expr_builtin_functions.cc | 22 - paddle/ap/src/index_expr/index_expr_util.cc | 18 - .../index_expr/valid_index_expr_builder.cc | 22 - .../paddle/pass/ap_kernel_define_helper.cc | 3 +- .../paddle/pass/ap_lower_fusion_op_pass.cc | 2 - 37 files changed, 9 insertions(+), 3856 deletions(-) delete mode 100644 paddle/ap/include/index_expr/builtin_frame_util.h delete mode 100644 paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h delete mode 100644 paddle/ap/include/index_expr/index_closure.h delete mode 100644 paddle/ap/include/index_expr/index_expr.h delete mode 100644 paddle/ap/include/index_expr/index_expr_builtin_functions.h delete mode 100644 paddle/ap/include/index_expr/index_expr_interpreter.h delete mode 100644 paddle/ap/include/index_expr/index_expr_method_class.h delete mode 100644 paddle/ap/include/index_expr/index_expr_util.h delete mode 100644 paddle/ap/include/index_expr/index_tuple_expr.h delete mode 100644 paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h delete mode 100644 paddle/ap/include/index_expr/index_tuple_expr_method_class.h delete mode 100644 
paddle/ap/include/index_expr/op_index_tuple_expr_signature.h delete mode 100644 paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h delete mode 100644 paddle/ap/include/index_expr/op_signature.h delete mode 100644 paddle/ap/include/index_expr/slice.h delete mode 100644 paddle/ap/include/index_expr/slice_method_class.h delete mode 100644 paddle/ap/include/index_expr/valid_index_expr_builder.h delete mode 100644 paddle/ap/include/index_expr/value.h delete mode 100644 paddle/ap/include/index_expr/value_method_class.h delete mode 100644 paddle/ap/include/paddle/indexed_ir_graph.h delete mode 100644 paddle/ap/include/paddle/indexed_ir_graph_util.h delete mode 100644 paddle/ap/include/paddle/indexed_ir_node.h delete mode 100644 paddle/ap/include/paddle/op_cuda_code_gen_impl.h delete mode 100644 paddle/ap/src/index_expr/index_closure.cc delete mode 100644 paddle/ap/src/index_expr/index_expr_builtin_functions.cc delete mode 100644 paddle/ap/src/index_expr/index_expr_util.cc delete mode 100644 paddle/ap/src/index_expr/valid_index_expr_builder.cc diff --git a/paddle/ap/CMakeLists.txt b/paddle/ap/CMakeLists.txt index 828ad2ee0434b3..a07ebaed5507ed 100644 --- a/paddle/ap/CMakeLists.txt +++ b/paddle/ap/CMakeLists.txt @@ -5,13 +5,6 @@ cc_library( SRCS ${axpr_srcs} DEPS ${axpr_deps}) -file(GLOB_RECURSE index_expr_srcs "src/index_expr/*.cc") -set(index_expr_deps axpr) -cc_library( - index_expr - SRCS ${index_expr_srcs} - DEPS ${index_expr_deps}) - file(GLOB_RECURSE ap_drr_srcs "src/drr/*.cc") set(ap_drr_deps axpr) cc_library( @@ -62,14 +55,7 @@ cc_library( DEPS ${ap_reified_drr_deps}) file(GLOB_RECURSE ap_pass_srcs "src/paddle/pass/*.cc") -set(ap_pass_deps - axpr - ap_pir - index_expr - ap_drr - ap_code_module - ap_code_gen - ap_reified_drr) +set(ap_pass_deps axpr ap_pir ap_drr ap_code_module ap_code_gen ap_reified_drr) cc_library( ap_pass SRCS ${ap_pass_srcs} diff --git a/paddle/ap/include/code_gen/builtin_frame_util.h 
b/paddle/ap/include/code_gen/builtin_frame_util.h index 12c4a98e49854b..a9fa1dcdd29899 100644 --- a/paddle/ap/include/code_gen/builtin_frame_util.h +++ b/paddle/ap/include/code_gen/builtin_frame_util.h @@ -22,9 +22,6 @@ #include "paddle/ap/include/code_module/func_declare_method_class.h" #include "paddle/ap/include/code_module/package_method_class.h" #include "paddle/ap/include/code_module/project_method_class.h" -#include "paddle/ap/include/index_expr/index_expr_method_class.h" -#include "paddle/ap/include/index_expr/index_tuple_expr_method_class.h" -#include "paddle/ap/include/index_expr/slice_method_class.h" namespace ap::code_gen { @@ -35,9 +32,6 @@ void VisitEachBuiltinFrameClass(const DoEachT& DoEach) { DoEach(code_module::GetFuncDeclareClass()); DoEach(code_module::GetCodeModuleClass()); DoEach(axpr::GetDimExprClass()); - DoEach(index_expr::GetSliceClass()); - DoEach(index_expr::GetIndexExprClass()); - DoEach(index_expr::GetIndexTupleExprClass()); DoEach(GetCodeGenResultClass()); } diff --git a/paddle/ap/include/code_gen/code_gen_ctx_method_class.h b/paddle/ap/include/code_gen/code_gen_ctx_method_class.h index b46930bb8a0cc0..121ccbe1cf4f72 100644 --- a/paddle/ap/include/code_gen/code_gen_ctx_method_class.h +++ b/paddle/ap/include/code_gen/code_gen_ctx_method_class.h @@ -26,7 +26,6 @@ #include "paddle/ap/include/code_gen/op_code_gen_ctx.h" #include "paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h" #include "paddle/ap/include/code_module/code_module.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" #include "paddle/ap/include/ir_match/native_or_ref_ir_value.h" #include "paddle/ap/include/registry/registry_singleton.h" @@ -189,84 +188,6 @@ struct CodeGenCtxMethodClass { symbol::ToString(dim_expr)}; return adt::Ok{}; } - - static adt::Result StaticMakeFusionOpCodeGenClass( - const ValueT& self_val, const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return 
This{}.MakeFusionOpCodeGenClass(self, args); - } - - using NativeOrRefIrValue = ir_match::NativeOrRefIrValue; - - adt::Result MakeFusionOpCodeGenClass( - const Self& self, const std::vector& packed_args_vec) { - const auto& packed_args = axpr::CastToPackedArgs(packed_args_vec); - const auto& [args, kwargs] = *packed_args; - ADT_CHECK(args->size() == 1) << adt::errors::TypeError{ - "'CodeGenCtx.make_fusion_op_code_gen_class' takes 1 positional " - "arguments but " + - std::to_string(args->size()) + " were given."}; - ADT_LET_CONST_REF(ir_op, IrOp::CastFrom(args->at(0))) - << adt::errors::TypeError{ - std::string() + - "the positional argument 1 of " - "'CodeGenCtx.make_fusion_op_code_gen_class' should " - "be able to cast to a NativeIrOp, PackedIrOp or RefIrOp."}; - ADT_LET_CONST_REF(input_index_loop_anchor_flags_lst, - kwargs->template Get>( - "input_index_loop_anchor_flags")) - << adt::errors::TypeError{ - std::string() + - "'CodeGenCtx.input_index_loop_anchor_flags' requires bool list " - "typed " - "keyword argument 'input_index_loop_anchor_flags'."}; - LoopAnchorFlags input_index_loop_anchor_flags; - { - input_index_loop_anchor_flags->reserve( - input_index_loop_anchor_flags_lst->size()); - for (const auto& elt : *input_index_loop_anchor_flags_lst) { - ADT_LET_CONST_REF(mask, elt.template CastTo()) - << adt::errors::TypeError{ - std::string() + - "'CodeGenCtx.input_index_loop_anchor_flags' requires bool " - "list typed " - "keyword argument 'input_index_loop_anchor_flags'."}; - input_index_loop_anchor_flags->emplace_back( - tLoopAnchorFlag{mask}); - } - } - ADT_LET_CONST_REF(output_index_loop_anchor_flags_lst, - kwargs->template Get>( - "output_index_loop_anchor_flags")) - << adt::errors::TypeError{ - std::string() + - "'CodeGenCtx.output_index_loop_anchor_flags' requires bool list " - "typed " - "keyword argument 'output_index_loop_anchor_flags'."}; - LoopAnchorFlags output_index_loop_anchor_flags; - { - output_index_loop_anchor_flags->reserve( - 
output_index_loop_anchor_flags_lst->size()); - for (const auto& elt : *output_index_loop_anchor_flags_lst) { - ADT_LET_CONST_REF(mask, elt.template CastTo()) - << adt::errors::TypeError{ - std::string() + - "'CodeGenCtx.output_index_loop_anchor_flags' requires bool " - "list typed " - "keyword argument 'output_index_loop_anchor_flags'."}; - output_index_loop_anchor_flags->emplace_back( - tLoopAnchorFlag{mask}); - } - } - - OpCodeGenCtx op_code_gen_ctx{self.shared_ptr(), - input_index_loop_anchor_flags, - output_index_loop_anchor_flags}; - ADT_LET_CONST_REF( - class_attrs, - ConvertFusionOpToClassAttrs(op_code_gen_ctx, ir_op)); - return axpr::TypeImpl>(class_attrs); - } }; template @@ -274,8 +195,6 @@ axpr::TypeImpl> GetCodeGenCtxClass() { using ImplMethods = CodeGenCtxMethodClass; static auto cls( axpr::MakeBuiltinClass("CodeGenCtx", [&](const auto& Define) { - Define("make_fusion_op_code_gen_class", - &ImplMethods::StaticMakeFusionOpCodeGenClass); Define("dim_expr_kernel_arg_id", &ImplMethods::StaticMakeAndCheckDimExprKernelArgId); Define("in_tensor_data_ptr_kernel_arg_id", diff --git a/paddle/ap/include/code_gen/matched_result_pattern_helper.h b/paddle/ap/include/code_gen/matched_result_pattern_helper.h index f2402430ebe895..73487fe676a55b 100644 --- a/paddle/ap/include/code_gen/matched_result_pattern_helper.h +++ b/paddle/ap/include/code_gen/matched_result_pattern_helper.h @@ -25,7 +25,6 @@ #include "paddle/ap/include/drr/result_pattern_helper.h" #include "paddle/ap/include/drr/value.h" #include "paddle/ap/include/graph/graph_helper.h" -#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" #include "paddle/ap/include/ir_match/graph_match_ctx.h" #include "paddle/ap/include/ir_match/graph_matcher.h" #include "paddle/ap/include/ir_match/ir_match_ctx.h" diff --git a/paddle/ap/include/code_gen/op_code_gen_ctx.h b/paddle/ap/include/code_gen/op_code_gen_ctx.h index 8dbbd8e941d2ac..c196218fb2f063 100644 --- a/paddle/ap/include/code_gen/op_code_gen_ctx.h +++ 
b/paddle/ap/include/code_gen/op_code_gen_ctx.h @@ -18,7 +18,6 @@ #include "paddle/ap/include/axpr/type.h" #include "paddle/ap/include/code_gen/kernel_arg_id.h" #include "paddle/ap/include/code_gen/loop_anchor_flags.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" #include "paddle/ap/include/ir_match/native_or_ref_ir_value.h" namespace ap::code_gen { diff --git a/paddle/ap/include/code_gen/value.h b/paddle/ap/include/code_gen/value.h index 835641eb8e9fdf..fa60f39441ff50 100644 --- a/paddle/ap/include/code_gen/value.h +++ b/paddle/ap/include/code_gen/value.h @@ -25,8 +25,6 @@ #include "paddle/ap/include/code_module/adt.h" #include "paddle/ap/include/code_module/code_module.h" #include "paddle/ap/include/code_module/data_type.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" #include "paddle/ap/include/ir_match/op_match_ctx.h" #include "paddle/ap/include/ir_match/tensor_match_ctx.h" diff --git a/paddle/ap/include/code_gen/value_method_class.h b/paddle/ap/include/code_gen/value_method_class.h index 9d6d33cfdf8bba..f67e01392593b7 100644 --- a/paddle/ap/include/code_gen/value_method_class.h +++ b/paddle/ap/include/code_gen/value_method_class.h @@ -20,8 +20,5 @@ #include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id_method_class.h" #include "paddle/ap/include/code_gen/in_tensor_data_ptr_kernel_arg_id_method_class.h" #include "paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id_method_class.h" -#include "paddle/ap/include/index_expr/index_expr_method_class.h" -#include "paddle/ap/include/index_expr/index_tuple_expr_method_class.h" -#include "paddle/ap/include/index_expr/slice_method_class.h" #include "paddle/ap/include/ir_match/op_match_ctx_method_class.h" #include "paddle/ap/include/ir_match/tensor_match_ctx_method_class.h" diff --git a/paddle/ap/include/graph/graph_descriptor.h b/paddle/ap/include/graph/graph_descriptor.h index 1e1f3f6a0f83e4..24e205c2ab61a7 100644 --- 
a/paddle/ap/include/graph/graph_descriptor.h +++ b/paddle/ap/include/graph/graph_descriptor.h @@ -23,9 +23,12 @@ namespace ap::graph { template -struct GraphDescriptor { - GraphDescriptor(const GraphDescriptor&) = default; - GraphDescriptor(GraphDescriptor&&) = default; +struct GraphDescriptor; + +template +struct GraphDescriptorInterface { + GraphDescriptorInterface(const GraphDescriptorInterface&) = default; + GraphDescriptorInterface(GraphDescriptorInterface&&) = default; template adt::Result VisitUpstreamNodes(const NodeT&, diff --git a/paddle/ap/include/index_expr/builtin_frame_util.h b/paddle/ap/include/index_expr/builtin_frame_util.h deleted file mode 100644 index f4cb67dbb0a989..00000000000000 --- a/paddle/ap/include/index_expr/builtin_frame_util.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/axpr/attr_map.h" -#include "paddle/ap/include/axpr/builtin_frame_util.h" -#include "paddle/ap/include/index_expr/value_method_class.h" - -namespace ap::index_expr { - -template -void VisitEachBuiltinFrameClass(const DoEachT& DoEach) { - DoEach(axpr::GetDimExprClass()); - DoEach(GetSliceClass()); - DoEach(GetIndexExprClass()); - DoEach(GetInIndexTupleExprSignatureClass()); - DoEach(GetOutIndexTupleExprSignatureClass()); - DoEach(GetOpIndexTupleExprSignatureClass()); -} - -template -ap::axpr::AttrMap MakeBuiltinFrameAttrMap() { - ap::axpr::AttrMap attr_map; - ap::axpr::VisitEachBuiltinFrameAttr( - [&](const std::string& k, const ValueT& v) { attr_map->Set(k, v); }); - VisitEachBuiltinFrameClass( - [&](const auto& cls) { attr_map->Set(cls.Name(), cls); }); - return attr_map; -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h b/paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h deleted file mode 100644 index 67fd61ded974e4..00000000000000 --- a/paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include "paddle/ap/include/common/unique_id.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" - -namespace ap::index_expr { - -class DimExprCudaCodeGenerator { - public: - using ArgName4DimExprT = - std::function(const symbol::DimExpr&)>; - explicit DimExprCudaCodeGenerator(std::ostringstream* ss, - const ArgName4DimExprT& ArgName4DimExprVal, - const std::string& index_type_name) - : ss_(ss), - ArgName4DimExpr(ArgName4DimExprVal), - index_type_name_(index_type_name) {} - - std::ostringstream& ss() { return *ss_; } - - adt::Result CodeGen(const symbol::DimExpr& dim_expr) { - if (const auto& arg_name = ArgName4DimExpr(dim_expr)) { - return arg_name.value(); - } - return dim_expr.Match([&](const auto& impl) { return CodeGenImpl(impl); }); - } - - private: - adt::Result CodeGenImpl(int64_t c) { return std::to_string(c); } - - adt::Result CodeGenImpl(const std::string& var) { - return adt::errors::TypeError{ - std::string() + "no kernel argument bound to DimExpr '" + var + "'"}; - } - - adt::Result CodeGenImpl( - const symbol::Negative& dim_expr) { - const auto& [operand] = *dim_expr; - ADT_LET_CONST_REF(operand_str, CodeGen(operand)); - return std::string() + "(-" + operand_str + ")"; - } - - adt::Result CodeGenImpl( - const symbol::Reciprocal&) { - return adt::errors::ValueError{ - "reciprocal value should be processed in '*'"}; - } - - adt::Result CodeGenImpl( - const symbol::Add& dim_expr) { - ADT_CHECK(dim_expr.operands->size() > 0); - ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); - std::string ret = first; - for (int i = 1; i < dim_expr.operands->size(); ++i) { - const auto& operand = dim_expr.operands->at(i); - if (operand.Has>()) { - const auto& [negtaive_operand] = - *operand.Get>(); - ADT_LET_CONST_REF(operand_str, CodeGen(negtaive_operand)); - ret += " - " + operand_str; - } else { - ADT_LET_CONST_REF(operand_str, CodeGen(operand)); - ret += " + " + operand_str; - } - } - return std::string() + "(" + 
ret + ")"; - } - - adt::Result CodeGenImpl( - const symbol::Mul& dim_expr) { - ADT_CHECK(dim_expr.operands->size() > 0); - ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); - std::string ret = first; - for (int i = 1; i < dim_expr.operands->size(); ++i) { - const auto& operand = dim_expr.operands->at(i); - if (operand.Has>()) { - const auto& [negtaive_operand] = - *operand.Get>(); - ADT_LET_CONST_REF(operand_str, CodeGen(negtaive_operand)); - ret += " / " + operand_str; - } else { - ADT_LET_CONST_REF(operand_str, CodeGen(operand)); - ret += " * " + operand_str; - } - } - return std::string() + "(" + ret + ")"; - } - - adt::Result CodeGenImpl( - const symbol::Max& dim_expr) { - ADT_CHECK(dim_expr.operands->size() > 0); - ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); - const std::string& var_name = ap::common::NewUniqueId("_ap_sym"); - ss() << index_type_name_ << " " << var_name << " = " << first << ";\n"; - for (int i = 1; i < dim_expr.operands->size(); ++i) { - const auto& operand = dim_expr.operands->at(i); - const std::string& operand_var_name = ap::common::NewUniqueId("_ap_sym"); - ADT_LET_CONST_REF(operand_str, CodeGen(operand)); - ss() << index_type_name_ << " " << operand_var_name << " = " - << operand_str << ";\n"; - ss() << var_name << " = (" << operand_var_name << " > " << var_name - << " ? 
" << operand_var_name << " : " << var_name << ");\n"; - } - return var_name; - } - - adt::Result CodeGenImpl( - const symbol::Min& dim_expr) { - ADT_CHECK(dim_expr.operands->size() > 0); - ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); - const std::string& var_name = ap::common::NewUniqueId("_ap_sym"); - ss() << index_type_name_ << " " << var_name << " = " << first << ";\n"; - for (int i = 1; i < dim_expr.operands->size(); ++i) { - const auto& operand = dim_expr.operands->at(i); - const std::string& operand_var_name = ap::common::NewUniqueId("_ap_sym"); - ADT_LET_CONST_REF(operand_str, CodeGen(operand)); - ss() << index_type_name_ << " " << operand_var_name << " = " - << operand_str << ";\n"; - ss() << var_name << " = (" << operand_var_name << " < " << var_name - << " ? " << operand_var_name << " : " << var_name << ");\n"; - } - return var_name; - } - - adt::Result CodeGenImpl( - const symbol::Broadcast& dim_expr) { - ADT_CHECK(dim_expr.operands->size() > 0); - ADT_LET_CONST_REF(first, CodeGen(dim_expr.operands->at(0))); - const std::string& var_name = ap::common::NewUniqueId("_ap_sym"); - ss() << index_type_name_ << " " << var_name << " = " << first << ";\n"; - for (int i = 1; i < dim_expr.operands->size(); ++i) { - const auto& operand = dim_expr.operands->at(i); - const std::string& operand_var_name = ap::common::NewUniqueId("_ap_sym"); - ADT_LET_CONST_REF(operand_str, CodeGen(operand)); - ss() << index_type_name_ << " " << operand_var_name << " = " - << operand_str << ";\n"; - ss() << var_name << " = (" << operand_var_name << " > " << var_name - << " ? 
" << operand_var_name << " : " << var_name << ");\n"; - } - return var_name; - } - - std::ostringstream* ss_; - ArgName4DimExprT ArgName4DimExpr; - std::string index_type_name_; -}; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_closure.h b/paddle/ap/include/index_expr/index_closure.h deleted file mode 100644 index 4d1c133b494880..00000000000000 --- a/paddle/ap/include/index_expr/index_closure.h +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include "paddle/ap/include/axpr/core_expr.h" -#include "paddle/ap/include/axpr/error.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_expr_interpreter.h" -#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" -#include "paddle/ap/include/index_expr/value.h" -#include "paddle/ap/include/index_expr/value_method_class.h" - -namespace ap::index_expr { - -using axpr::CoreExpr; -using axpr::Lambda; - -struct IndexClosureData { - const ap::index_expr::Val ctx; - const adt::List inputs_meta; - const adt::List outputs_meta; - const adt::List in_vars; - - bool operator==(const IndexClosureData& other) const { - return other.ctx == this->ctx && other.inputs_meta == this->inputs_meta && - other.outputs_meta == this->outputs_meta && - other.in_vars == this->in_vars; - } -}; - -using Nice2IndexLambdas = - std::map>>; - -struct OrderedOneofIndexClosureImpl { - std::shared_ptr interpreter; - IndexClosureData closure_data; - Nice2IndexLambdas nice2index_lambdas; - - adt::Result operator()( - const IndexTupleExpr&) const; - - bool operator==(const OrderedOneofIndexClosureImpl& other) const { - return other.interpreter == this->interpreter && - other.closure_data == this->closure_data && - other.nice2index_lambdas == this->nice2index_lambdas; - } - - private: - adt::Result CallLambda( - const Lambda& lambda, const IndexTupleExpr&) const; -}; -ADT_DEFINE_RC(OrderedOneofIndexClosure, OrderedOneofIndexClosureImpl); - -using TrackedIndexesTransformImpl = - std::variant; - -struct TrackedIndexesTransform : public TrackedIndexesTransformImpl { - using TrackedIndexesTransformImpl::TrackedIndexesTransformImpl; - ADT_DEFINE_VARIANT_METHODS(TrackedIndexesTransformImpl); -}; - -using OpIndexesTransformSignature = - ap::index_expr::OpSignature; - -struct RecordableIndexClosureImpl { - OpIndexesTransformSignature op_indexes_transform_signature; - - adt::Result operator()( - const 
IndexTupleExpr&) const; - - bool operator==(const RecordableIndexClosureImpl& other) const { - return other.op_indexes_transform_signature == - this->op_indexes_transform_signature; - } -}; -ADT_DEFINE_RC(RecordableIndexClosure, RecordableIndexClosureImpl); - -using IndexClosureImpl = - std::variant; - -struct IndexClosure : public IndexClosureImpl { - using IndexClosureImpl::IndexClosureImpl; - ADT_DEFINE_VARIANT_METHODS(IndexClosureImpl); - - adt::Result operator()( - const IndexTupleExpr& indexes_expr) const; -}; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr.h b/paddle/ap/include/index_expr/index_expr.h deleted file mode 100644 index 6f1d0c9f0e3d8b..00000000000000 --- a/paddle/ap/include/index_expr/index_expr.h +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include "paddle/ap/include/axpr/adt.h" -#include "paddle/ap/include/axpr/builtin_class_instance.h" -#include "paddle/ap/include/axpr/type.h" -#include "paddle/ap/include/index_expr/slice.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" - -namespace ap::index_expr { - -struct IndexTupleExpr; - -std::string IndexTupleExprToString(const std::shared_ptr&); - -struct UndefinedIndexExprImpl : public std::monostate { - using std::monostate::monostate; - - std::string ToString() const { return "IndexExpr.Undefined"; } -}; - -ADT_DEFINE_RC(UndefinedIndexExpr, UndefinedIndexExprImpl); - -struct PtrGetItemImpl { - std::string ptr_var_name; - std::shared_ptr indexes_expr; - symbol::DimExpr range; - - bool operator==(const PtrGetItemImpl& other) const { - return (other.ptr_var_name == this->ptr_var_name) && - other.indexes_expr == this->indexes_expr && - other.range == this->range; - } - - std::string ToString() const { - return std::string() + "IndexExpr.PtrGetItem(ptr_var_name=" + ptr_var_name + - ", indexes_expr=" + IndexTupleExprToString(indexes_expr) + - ", range=" + symbol::ToString(range) + ")"; - } -}; - -ADT_DEFINE_RC(PtrGetItem, PtrGetItemImpl); - -struct IndexExprDomainImpl { - symbol::DimExpr range; - - bool operator==(const IndexExprDomainImpl& other) const { - return other.range == this->range; - } - - std::string ToString() const { - return std::string() + "IndexExpr.Domain(" + symbol::ToString(range) + ")"; - } -}; - -ADT_DEFINE_RC(IndexExprDomain, const IndexExprDomainImpl); - -template -struct IndexExprBroadcastMaskImpl { - symbol::DimExpr dim; - Expr index_expr; - - bool operator==(const IndexExprBroadcastMaskImpl& other) const { - return other.dim == this->dim && other.index_expr == this->index_expr; - } - - std::string ToString() const { - return std::string() + - "IndexExpr.BroadcastMask(dim=" + symbol::ToString(dim) + - ", index_expr=" + index_expr.ToString() + ")"; - } -}; - -template 
-ADT_DEFINE_RC(IndexExprBroadcastMask, const IndexExprBroadcastMaskImpl); - -// IndexExprSlice * IndexExprAffine == IdentityFunc if fields are same. -template -struct IndexExprSliceImpl { - index_expr::Slice slice; - symbol::DimExpr range; - Expr index_expr; - - bool operator==(const IndexExprSliceImpl& other) const { - return (other.slice == this->slice) && (other.range == this->range) && - (other.index_expr == this->index_expr); - } - - std::string ToString() const { - return index_expr.ToString() + ".slice(" + slice->ToString() + - ", range=" + symbol::ToString(range) + ")"; - } -}; - -template -ADT_DEFINE_RC(IndexExprSlice, const IndexExprSliceImpl); - -template -struct IndexExprAffineImpl { - index_expr::Slice slice; - symbol::DimExpr range; - Expr index_expr; - - bool operator==(const IndexExprAffineImpl& other) const { - return (other.slice == this->slice) && (other.range == this->range) && - (other.index_expr == this->index_expr); - } - - std::string ToString() const { - return index_expr.ToString() + ".affine(" + slice->ToString() + - ", range=" + symbol::ToString(range) + ")"; - } -}; - -template -ADT_DEFINE_RC(IndexExprAffine, const IndexExprAffineImpl); - -template -struct DisjointUnionImpl { - T lhs; - T rhs; - - bool operator==(const DisjointUnionImpl& other) const { - return (other.lhs == this->lhs) && (other.rhs == this->rhs); - } - - std::string ToString() const { - return std::string() + "IndexExpr.DisjointUnion(" + lhs.ToString() + ", " + - rhs.ToString() + ")"; - } -}; - -template -ADT_DEFINE_RC(DisjointUnion, const DisjointUnionImpl); - -template -using IndexExprBase = std::variant, - IndexExprSlice, - IndexExprAffine, - DisjointUnion>; - -struct IndexExpr : public IndexExprBase { - using IndexExprBase::IndexExprBase; - ADT_DEFINE_VARIANT_METHODS(IndexExprBase); - - std::string ToString() const { - return Match([](const auto& impl) { return impl->ToString(); }); - } -}; - -template -axpr::TypeImpl> GetIndexExprClass(); - -} // namespace 
ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_builtin_functions.h b/paddle/ap/include/index_expr/index_expr_builtin_functions.h deleted file mode 100644 index ac3dc0db70edc0..00000000000000 --- a/paddle/ap/include/index_expr/index_expr_builtin_functions.h +++ /dev/null @@ -1,441 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/axpr/builtin_functions.h" -#include "paddle/ap/include/index_expr/index_expr_util.h" -#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" -#include "paddle/ap/include/index_expr/value.h" - -namespace ap::index_expr { - -using adt::Maybe; -using adt::Result; - -template -Result MakePtrGetItem(const Val&, const std::vector& args); - -template -Result MakeIndexExprBroadcastMask(const Val&, - const std::vector& args); - -template -Result MakeSlice(const Val&, const std::vector& args); - -template -Result MakeIndexExprSlice(const Val&, const std::vector& args); - -template -Result MakeIndexExprAffine(const Val&, const std::vector& args); - -template -Result MakeDisjointUnion(const Val&, const std::vector& args); - -template -Result MakeIndexTupleExprPermute(const Val&, const std::vector& args); - -template -Result MakeIndexTupleExprReshape(const Val&, const std::vector& args); - -template -Result MakeIndexTupleExprTransform(axpr::InterpreterBase* 
interpreter, - const Val& obj, - const std::vector& args); - -template -Result MakeOpIndexTupleExprSignature(const Val&, - const std::vector& args); - -template -Result MakeInIndexTupleExprSignature(const Val&, - const std::vector& args); - -template -Result MakeOutIndexTupleExprSignature(const Val&, - const std::vector& args); - -template -inline Maybe TryGetImplIndexExprValue(const Val& val) { - const auto& ret = val.template TryGet(); - if (ret.template HasOkValue()) { - return ret.GetOkValue(); - } - return adt::Nothing{}; -} - -template -inline adt::Result TryGetDimExpr(const Val& val) { - using RetT = adt::Result; - return val.Match([](int64_t c) -> RetT { return symbol::DimExpr{c}; }, - [](const axpr::BuiltinClassInstance& instance) -> RetT { - return instance.template TryGet(); - }, - [&](const auto&) -> RetT { - return adt::errors::TypeError{ - "TryGetDimExpr() failed. argument 1 should an int or " - "DimExpr (not " + - axpr::GetTypeName(val) + ")"}; - }); -} - -template -inline Maybe TryGetInt64(const Val& val) { - return val.Match( - [](int64_t c) -> Maybe { return c; }, - [](const symbol::DimExpr& dim_expr) -> Maybe { - return dim_expr.Match( - [](const int64_t c) -> Maybe { return c; }, - [](const auto&) -> Maybe { return adt::Nothing{}; }); - }, - [&](const auto&) -> Maybe { return adt::Nothing{}; }); -} - -template -Result MakePtrGetItem(const Val&, const std::vector& args) { - if (args.size() != 3) { - return adt::errors::TypeError{ - std::string("PtrGetItem takes 3 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - const auto& opt_arg1 = TryGetImplIndexExprValue(args.at(1)); - ADT_LET_CONST_REF(dim_expr, TryGetDimExpr(args.at(2))); - return std::visit(::common::Overloaded{ - [&](const std::string& ptr_var_name, - const IndexTupleExpr& indexes_expr) -> Result { - return PtrGetItem{ - ptr_var_name, - std::make_shared(indexes_expr), - dim_expr}; - }, - [&](const auto&, const auto&) -> Result { - return 
adt::errors::InvalidArgumentError{ - "wrong argument type for PtrGetItem"}; - }}, - args.at(0).variant(), - opt_arg1.variant()); -} - -namespace detail { - -template -Result ConvertResult(const T& result) { - return result.Match([](const auto& impl) -> Result { return impl; }); -} - -} // namespace detail - -template -Result MakeIndexExprBroadcastMask(const Val&, - const std::vector& args) { - if (args.size() != 2) { - return adt::errors::TypeError{ - std::string("IndexExprBroadcastMask takes 2 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - ADT_LET_CONST_REF(dim_expr, TryGetDimExpr(args.at(0))); - const auto& opt_arg1 = TryGetImplIndexExprValue(args.at(1)); - ValidIndexExprBuilder builder{}; - const auto& pattern_match = ::common::Overloaded{ - [&](const IndexExpr& index_expr) -> Result { - return detail::ConvertResult( - builder.BroadcastMask(dim_expr, index_expr)); - }, - [&](const auto&) -> Result { - return adt::errors::InvalidArgumentError{ - "wrong argument type for IndexExprBroadcastMask"}; - }}; - return std::visit(pattern_match, opt_arg1.variant()); -} - -template -Result MakeSlice(const Val&, const std::vector& args) { - if (args.size() != 3) { - return adt::errors::TypeError{std::string("Slice takes 3 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - ADT_LET_CONST_REF(start, TryGetDimExpr(args.at(0))); - ADT_LET_CONST_REF(stop, TryGetDimExpr(args.at(1))); - ADT_LET_CONST_REF(step, TryGetDimExpr(args.at(1))); - return Val{Slice{start, stop, step}}; -} - -template -Result MakeIndexExprSlice(const Val&, const std::vector& args) { - if (args.size() != 3) { - return adt::errors::TypeError{ - std::string("IndexExprSlice takes 3 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - const auto& opt_slice = TryGetImplIndexExprValue(args.at(0)); - ADT_LET_CONST_REF(range, TryGetDimExpr(args.at(1))); - const auto& opt_index_expr = TryGetImplIndexExprValue(args.at(2)); - const auto& pattern_match = 
::common::Overloaded{ - [](const Slice& slice, const IndexExpr& expr) -> Result { - ValidIndexExprBuilder builder{}; - return detail::ConvertResult(builder.Slice(slice, range, expr)); - }, - [](const auto&, const auto&) -> Result { - return adt::errors::InvalidArgumentError{ - "wrong argument type for IndexExprSlice"}; - }}; - return std::visit( - pattern_match, opt_slice.variant(), opt_index_expr.variant()); -} - -template -Result MakeIndexExprAffine(const Val&, const std::vector& args) { - if (args.size() != 3) { - return adt::errors::TypeError{ - std::string("IndexExprAffine takes 3 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - const auto& opt_slice = TryGetImplIndexExprValue(args.at(0)); - ADT_LET_CONST_REF(range, TryGetDimExpr(args.at(1))); - const auto& opt_index_expr = TryGetImplIndexExprValue(args.at(2)); - return std::visit( - ::common::Overloaded{ - [](const Slice& slice, const IndexExpr& index_expr) -> Result { - ValidIndexExprBuilder builder{}; - return detail::ConvertResult( - builder.Affine(slice, range, index_expr)); - }, - [](const auto&, const auto&) -> Result { - return adt::errors::InvalidArgumentError{ - "wrong argument type for IndexExprAffine"}; - }}, - opt_slice.variant(), - opt_index_expr.variant()); -} - -template -Result MakeDisjointUnion(const Val&, const std::vector& args) { - const auto& opt_lhs = TryGetImplIndexExprValue(args.at(1)); - const auto& opt_rhs = TryGetImplIndexExprValue(args.at(1)); - return std::visit( - ::common::Overloaded{ - [](const IndexExpr& lhs, const IndexExpr& rhs) -> Result { - ValidIndexExprBuilder builder{}; - return detail::ConvertResult(builder.DisjointUnion(lhs, rhs)); - }, - [](const auto&, const auto&) -> Result { - return adt::errors::InvalidArgumentError{ - "wrong argument type for DisjointUnion"}; - }}, - opt_lhs.variant(), - opt_rhs.variant()); -} - -template -inline Maybe> TryGetInt64List(const Val& val) { - return val.Match( - [](const adt::List& l) -> Maybe> { - adt::List 
ret; - ret->reserve(l->size()); - for (const auto& elt : *l) { - const auto& opt_int = TryGetInt64(elt); - if (!opt_int.template Has()) { - return adt::Nothing{}; - } - ret->push_back(opt_int.template Get()); - } - return ret; - }, - [](const auto&) -> Maybe> { return adt::Nothing{}; }); -} - -template -inline adt::Result> TryGetDimExprList( - const Val& val) { - using RetT = adt::Result>; - ADT_LET_CONST_REF(l, val.template CastTo>()); - adt::List ret; - ret->reserve(l->size()); - for (const auto& elt : *l) { - ADT_LET_CONST_REF(int_val, TryGetDimExpr(elt)); - ret->push_back(int_val); - } - return ret; -} - -template -Result MakeIndexTupleExprPermute(const Val&, - const std::vector& args) { - if (args.size() != 2) { - return adt::errors::TypeError{ - std::string("IndexTupleExprPermute takes 2 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - const auto& opt_perms = TryGetInt64List(args.at(0)); - const auto& opt_expr = TryGetImplIndexExprValue(args.at(1)); - ValidIndexExprBuilder builder{}; - return std::visit( - ::common::Overloaded{ - [&](const adt::List& perms, - const IndexTupleExpr& expr) -> Result { - return detail::ConvertResult(builder.Permute(perms, expr)); - }, - [](const auto&, const auto&) -> Result { - return adt::errors::InvalidArgumentError{ - "wrong argument type for IndexTupleExprPermute"}; - }}, - opt_perms.variant(), - opt_expr.variant()); -} - -template -Result MakeIndexTupleExprReshape(const Val&, - const std::vector& args) { - if (args.size() != 2) { - return adt::errors::TypeError{ - std::string("IndexTupleExprReshape takes 2 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - ADT_LET_CONST_REF(shape, TryGetDimExprList(args.at(0))); - const auto& opt_expr = TryGetImplIndexExprValue(args.at(1)); - ValidIndexExprBuilder builder{}; - return std::visit( - ::common::Overloaded{ - [&](const IndexTupleExpr& expr) -> Result { - return detail::ConvertResult(builder.Reshape(shape, expr)); - }, - [](const 
auto&) -> Result { - return adt::errors::InvalidArgumentError{ - "wrong argument type for IndexTupleExprReshape"}; - }}, - opt_expr.variant()); -} - -template -Result MakeIndexTupleExprTransform(axpr::InterpreterBase* interpreter, - const Val&, - const std::vector& args) { - if (args.size() < 1) { - return adt::errors::TypeError{ - "IndexTupleExprTransform takes at least 1 argument but 0 were given."}; - } - const auto& opt_expr = TryGetImplIndexExprValue(args.at(0)); - if (!opt_expr.template Has()) { - return adt::errors::TypeError{ - "The first argument of IndexTupleExprTransform must be a " - "IndexTupleExpr."}; - } - const auto& indexes_expr = opt_expr.template Get(); - const auto& opt_rank = IndexTupleExprGetRank(indexes_expr); - if (!opt_rank.template Has()) { - return adt::errors::TypeError{ - "The first argument of IndexTupleExprTransform must be a ranked " - "IndexTupleExpr."}; - } - const auto& opt_dim_exprs = IndexTupleExprGetRanges(indexes_expr); - if (!opt_dim_exprs.template Has>()) { - return adt::errors::RuntimeError{ - "error occurred where calling IndexTupleExprGetDims"}; - } - const auto& dim_exprs = - opt_dim_exprs.template Get>(); - if (opt_rank.template Get() != args.size() - 1) { - return adt::errors::TypeError{ - "The rank of first argument must equal to number of lambdas."}; - } - adt::List transform_index_exprs; - transform_index_exprs->reserve(args.size() - 1); - for (int i = 1; i < args.size(); ++i) { - const auto& opt_closure = args.at(i).template TryGet>(); - ADT_RETURN_IF_ERR(opt_closure); - const auto& closure = opt_closure.GetOkValue(); - - if (closure->lambda->args.size() != 1) { - return adt::errors::TypeError{std::string("Argument ") + - std::to_string(i) + - " is not a single-argumented closure."}; - } - int idx = i - 1; - IndexExprDomain domain{dim_exprs->at(idx)}; - const auto& ret_lambda_call = - interpreter->InterpretCall(closure, {Val{domain}}); - ADT_RETURN_IF_ERR(ret_lambda_call); - const auto& ret_index_expr = - 
TryGetImplIndexExprValue(ret_lambda_call.GetOkValue()); - if (!ret_index_expr.template Has()) { - return adt::errors::TypeError{std::string("closure of argument") + - std::to_string(i) + - " does not return a IndexExpr."}; - } - transform_index_exprs->push_back(ret_index_expr.template Get()); - } - ValidIndexExprBuilder builder{}; - ADT_LET_CONST_REF(ret, - detail::ConvertResult(builder.Transform( - transform_index_exprs, indexes_expr))); - return ret; -} - -template -Result MakeOpIndexTupleExprSignature(const Val&, - const std::vector& args) { - if (args.size() != 2) { - return adt::errors::TypeError{ - std::string("OpIndexTupleExprSignature takes 2 arguments but ") + - std::to_string(args.size()) + "were given."}; - } - const auto& in_sig = args.at(0); - const auto& opt_in = in_sig.template TryGet(); - ADT_RETURN_IF_ERR(opt_in); - const auto& in = opt_in.GetOkValue(); - const auto& out_sig = args.at(1); - const auto& opt_out = out_sig.template TryGet(); - ADT_RETURN_IF_ERR(opt_out); - const auto& out = opt_out.GetOkValue(); - return OpIndexTupleExprSignature{in, out}; -} - -template -Result MakeInIndexTupleExprSignature(const Val&, - const std::vector& args) { - adt::List indexes_exprs; - indexes_exprs->reserve(args.size()); - for (const auto& arg : args) { - const auto& maybe_indexes_expr = - TryGetImplIndexExprValue(arg); - if (!maybe_indexes_expr.template Has()) { - return adt::errors::InvalidArgumentError{ - "only arguments of `IndexTupleExpr` type is valid for " - "InIndexTupleExprSignature"}; - } - indexes_exprs->push_back(maybe_indexes_expr.template Get()); - } - return InIndexTupleExprSignature{indexes_exprs}; -} - -template -Result MakeOutIndexTupleExprSignature(const Val&, - const std::vector& args) { - adt::List indexes_exprs; - indexes_exprs->reserve(args.size()); - for (const auto& arg : args) { - const auto& maybe_indexes_expr = - TryGetImplIndexExprValue(arg); - if (!maybe_indexes_expr.template Has()) { - return adt::errors::InvalidArgumentError{ - 
"only arguments of `IndexTupleExpr` type is valid for " - "OutIndexTupleExprSignature"}; - } - indexes_exprs->push_back(maybe_indexes_expr.template Get()); - } - return OutIndexTupleExprSignature{indexes_exprs}; -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_interpreter.h b/paddle/ap/include/index_expr/index_expr_interpreter.h deleted file mode 100644 index f4f0d845c15801..00000000000000 --- a/paddle/ap/include/index_expr/index_expr_interpreter.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include "paddle/ap/include/axpr/builtin_functions.h" -#include "paddle/ap/include/axpr/core_expr.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_expr_builtin_functions.h" -#include "paddle/ap/include/index_expr/value.h" -#include "paddle/ap/include/index_expr/value_method_class.h" - -namespace ap::index_expr { - -class IndexExprInterpreter { - public: - IndexExprInterpreter(); - IndexExprInterpreter(const IndexExprInterpreter&) = delete; - IndexExprInterpreter(IndexExprInterpreter&&) = delete; - - Result operator()(const axpr::Lambda& lambda, - const std::vector& args) const { - return adt::errors::NotImplementedError{ - "IndexExprInterpreter::operator()(lambda, args)"}; - } - - Result operator()( - const std::unordered_map>& - global_functions, - const axpr::Lambda& lambda, - const std::vector& args) const { - return adt::errors::NotImplementedError{ - "IndexExprInterpreter::operator()(global_functions, lambda, args)"}; - } -}; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_method_class.h b/paddle/ap/include/index_expr/index_expr_method_class.h deleted file mode 100644 index ee733851f5f65e..00000000000000 --- a/paddle/ap/include/index_expr/index_expr_method_class.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/axpr/method_class.h" -#include "paddle/ap/include/axpr/naive_class_ops.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_expr_builtin_functions.h" - -namespace ap::index_expr { - -template -struct IndexExprMethodClass { - using This = IndexExprMethodClass; - using Self = IndexExpr; - - static adt::Result ToString(const ValueT& self_val, - const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return self.ToString(); - } -}; - -template -axpr::TypeImpl> GetIndexExprClass() { - using ImplMethods = IndexExprMethodClass; - static auto cls(axpr::MakeBuiltinClass( - "IndexExpr", - [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); - using Self = typename ImplMethods::Self; - return axpr::MakeGlobalNaiveClassOps(cls); -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_expr_util.h b/paddle/ap/include/index_expr/index_expr_util.h deleted file mode 100644 index e054903e5b2660..00000000000000 --- a/paddle/ap/include/index_expr/index_expr_util.h +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/axpr/error.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" - -namespace ap::index_expr { - -using adt::Maybe; - -inline Maybe IndexTupleExprGetRank(const IndexTupleExpr& expr) { - return expr.Match( - [](const UndefinedIndexTupleExpr&) -> Maybe { - return adt::Nothing{}; - }, - [](const NothingIndexTupleExpr&) -> Maybe { - return adt::Nothing{}; - }, - [](const IntArrayLikeIndexTupleExpr&) -> Maybe { - return adt::Nothing{}; - }, - [](const IndexTupleExprDomain& domain) -> Maybe { - return domain->ranges->size(); - }, - [](const IndexTupleExprPermute& perm) -> Maybe { - return perm->perms->size(); - }, - [](const IndexTupleExprReshape& reshape) - -> Maybe { return reshape->shape->size(); }, - [](const IndexTupleExprTransform& transform) - -> Maybe { return transform->index_exprs->size(); }); -} - -inline Maybe IndexExprGetRange(const IndexExpr& index_expr) { - return index_expr.Match( - [](const UndefinedIndexExpr&) -> Maybe { - return adt::Nothing{}; - }, - [](const PtrGetItem& ptr_get_item) -> Maybe { - return ptr_get_item->range; - }, - [](const IndexExprDomain& domain) -> Maybe { - return domain->range; - }, - [](const IndexExprBroadcastMask& mask) - -> Maybe { return mask->dim; }, - [](const IndexExprSlice& index_slice) - -> Maybe { return index_slice->range; }, - [](const IndexExprAffine& index_affine) - -> Maybe { return index_affine->range; }, - [](const DisjointUnion& union_expr) -> Maybe { - const auto& opt_lhs_dim_expr = IndexExprGetRange(union_expr->lhs); - const auto& opt_rhs_dim_expr = IndexExprGetRange(union_expr->rhs); - return std::visit( - ::common::Overloaded{ - [](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) - -> Maybe { return lhs + rhs; }, - [](const auto&, const auto&) -> Maybe { - return adt::Nothing{}; - }}, - opt_lhs_dim_expr.variant(), - opt_rhs_dim_expr.variant()); - }); -} 
- -inline Maybe IndexExprGetDomain(const IndexExpr& index_expr) { - return index_expr.Match( - [](const UndefinedIndexExpr&) -> Maybe { - return adt::Nothing{}; - }, - [](const PtrGetItem& ptr_get_item) -> Maybe { - return ptr_get_item->range; - }, - [](const IndexExprDomain& domain) -> Maybe { - return domain->range; - }, - [](const IndexExprBroadcastMask& mask) - -> Maybe { - return IndexExprGetDomain(mask->index_expr); - }, - [](const IndexExprSlice& index_slice) - -> Maybe { - return IndexExprGetDomain(index_slice->index_expr); - }, - [](const IndexExprAffine& index_affine) - -> Maybe { - return IndexExprGetDomain(index_affine->index_expr); - }, - [](const DisjointUnion& union_expr) -> Maybe { - const auto& lhs = IndexExprGetDomain(union_expr->lhs); - const auto& rhs = IndexExprGetDomain(union_expr->rhs); - const auto& pattern_match = ::common::Overloaded{ - [&](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) { - if (lhs == rhs) { - return Maybe{lhs}; - } else { - return Maybe{adt::Nothing{}}; - } - }, - [&](const auto&, const auto&) { - return Maybe{adt::Nothing{}}; - }}; - return std::visit(pattern_match, lhs.variant(), rhs.variant()); - }); -} - -inline Maybe> IndexTupleExprGetRanges( - const IndexTupleExpr& expr) { - return expr.Match( - [](const UndefinedIndexTupleExpr&) -> Maybe> { - return adt::Nothing{}; - }, - [](const NothingIndexTupleExpr&) -> Maybe> { - return adt::Nothing{}; - }, - [](const IntArrayLikeIndexTupleExpr&) - -> Maybe> { return adt::Nothing{}; }, - [](const IndexTupleExprDomain& domain) - -> Maybe> { return domain->ranges; }, - [](const IndexTupleExprPermute& perm) - -> Maybe> { - const auto& opt_origin_dim_exprs = - IndexTupleExprGetRanges(perm->indexes_expr); - if (opt_origin_dim_exprs.Has()) { - return adt::Nothing{}; - } - const auto& origin_dim_exprs = - opt_origin_dim_exprs.Get>(); - adt::List ret; - ret->reserve(perm->perms->size()); - for (const int idx : *perm->perms) { - ret->push_back(origin_dim_exprs->at(idx)); - } - 
return ret; - }, - [](const IndexTupleExprReshape& reshape) - -> Maybe> { return reshape->shape; }, - [](const IndexTupleExprTransform& transform) - -> Maybe> { - adt::List ret; - ret->reserve(transform->index_exprs->size()); - for (const auto& index_expr : *transform->index_exprs) { - const auto& opt_dim_expr = IndexExprGetRange(index_expr); - if (opt_dim_expr.Has()) { - return adt::Nothing{}; - } - ret->push_back(opt_dim_expr.Get()); - } - return ret; - }); -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_tuple_expr.h b/paddle/ap/include/index_expr/index_tuple_expr.h deleted file mode 100644 index 09c90cd9496cf7..00000000000000 --- a/paddle/ap/include/index_expr/index_tuple_expr.h +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include "paddle/ap/include/axpr/adt.h" -#include "paddle/ap/include/axpr/builtin_class_instance.h" -#include "paddle/ap/include/axpr/type.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/slice.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" - -namespace ap::index_expr { - -struct UndefinedIndexTupleExprImpl : public std::monostate { - using std::monostate::monostate; - - std::string ToString() const { return "IndexTupleExpr.Undefined"; } - - const char* TypeName() const { return "UndefinedIndexTupleExpr"; } -}; -ADT_DEFINE_RC(UndefinedIndexTupleExpr, UndefinedIndexTupleExprImpl); - -struct NothingIndexTupleExprImpl : public std::monostate { - using std::monostate::monostate; - - std::string ToString() const { return "IndexTupleExpr.Nothing"; } - - const char* TypeName() const { return "NothingIndexTupleExpr"; } -}; -ADT_DEFINE_RC(NothingIndexTupleExpr, NothingIndexTupleExprImpl); - -struct IntArrayLikeIndexTupleExprImpl : public std::monostate { - using std::monostate::monostate; - - std::string ToString() const { return "IndexTupleExpr.IntArrayLike"; } - - const char* TypeName() const { return "IntArrayLikeIndexTupleExpr"; } -}; -ADT_DEFINE_RC(IntArrayLikeIndexTupleExpr, IntArrayLikeIndexTupleExprImpl); - -struct IndexTupleExprDomainImpl { - adt::List ranges; - bool operator==(const IndexTupleExprDomainImpl& other) const { - return other.ranges == this->ranges; - } - - std::string ToString() const { - std::ostringstream ss; - ss << "["; - int i = 0; - for (const auto& elt : *ranges) { - if (i++ > 0) { - ss << ", "; - } - ss << symbol::ToString(elt); - } - ss << "]"; - return std::string() + "IndexTupleExpr.Domain(" + ss.str() + ")"; - } - - const char* TypeName() const { return "IndexTupleExprDomain"; } -}; -ADT_DEFINE_RC(IndexTupleExprDomain, const IndexTupleExprDomainImpl); - -template -struct IndexTupleExprPermuteImpl { - adt::List perms; - Expr indexes_expr; - - bool 
operator==(const IndexTupleExprPermuteImpl& other) const { - return other.perms == this->perms && - other.indexes_expr == this->indexes_expr; - } - - std::string ToString() const { - std::ostringstream ss; - ss << "["; - int i = 0; - for (int64_t perm : *perms) { - if (i++ > 0) { - ss << ", "; - } - ss << perm; - } - ss << "]"; - return indexes_expr.ToString() + ".permute(" + ss.str() + ")"; - } - - const char* TypeName() const { return "IndexTupleExprPermute"; } -}; - -template -ADT_DEFINE_RC(IndexTupleExprPermute, const IndexTupleExprPermuteImpl); - -template -struct IndexTupleExprReshapeImpl { - adt::List shape; - Expr indexes_expr; - - bool operator==(const IndexTupleExprReshapeImpl& other) const { - return other.shape == this->shape && - other.indexes_expr == this->indexes_expr; - } - - std::string ToString() const { - std::ostringstream ss; - ss << "["; - int i = 0; - for (const auto& elt : *shape) { - if (i++ > 0) { - ss << ", "; - } - ss << symbol::ToString(elt); - } - ss << "]"; - return indexes_expr.ToString() + ".reshape(" + ss.str() + ")"; - } - - const char* TypeName() const { return "IndexTupleExprReshape"; } -}; -template -ADT_DEFINE_RC(IndexTupleExprReshape, const IndexTupleExprReshapeImpl); - -template -struct IndexTupleExprTransformImpl { - adt::List index_exprs; - Expr indexes_expr; - - bool operator==(const IndexTupleExprTransformImpl& other) const { - return other.index_exprs == this->index_exprs && - other.indexes_expr == this->indexes_expr; - } - - std::string ToString() const { - std::ostringstream ss; - ss << "["; - int i = 0; - for (const auto& elt : *index_exprs) { - if (i++ > 0) { - ss << ", "; - } - ss << elt.ToString(); - } - ss << "]"; - return indexes_expr.ToString() + ".transform(" + ss.str() + ")"; - } - - const char* TypeName() const { return "IndexTupleExprTransform"; } -}; -template -ADT_DEFINE_RC(IndexTupleExprTransform, const IndexTupleExprTransformImpl); - -template -using IndexTupleExprBase = std::variant, - 
IndexTupleExprReshape, - IndexTupleExprTransform>; - -struct IndexTupleExpr : public IndexTupleExprBase { - using IndexTupleExprBase::IndexTupleExprBase; - ADT_DEFINE_VARIANT_METHODS(IndexTupleExprBase); - - const char* TypeName() const { - return Match([](const auto& impl) { return impl->TypeName(); }); - } - - std::string ToString() const { - return Match([](const auto& impl) { return impl->ToString(); }); - } -}; - -inline std::string IndexTupleExprToString( - const std::shared_ptr& indexes_expr) { - return indexes_expr->ToString(); -} - -template -axpr::TypeImpl> GetIndexTupleExprClass(); - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h b/paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h deleted file mode 100644 index a12baa6ac45066..00000000000000 --- a/paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/common/unique_id.h" -#include "paddle/ap/include/index_expr/dim_expr_cuda_code_generator.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" - -namespace ap::index_expr { - -class IndexTupleExprCudaCodeGenerator { - public: - using ArgName4DimExprT = - std::function(const symbol::DimExpr&)>; - IndexTupleExprCudaCodeGenerator( - std::ostringstream* ss, - const std::vector& loop_var_names, - const ArgName4DimExprT& ArgName4DimExpr) - : ss_(ss), - loop_var_names_(loop_var_names), - index_type_name_("int64_t"), - dim_expr_code_gen_(ss, ArgName4DimExpr, "int64_t") {} - - std::ostringstream& ss() { return *ss_; } - - adt::Result CodeGen(const IndexTupleExpr& indexes_expr) { - return indexes_expr.Match( - [&](const IndexTupleExprDomain& domain) -> adt::Result { - return CodeGenImpl(domain); - }, - [&](const auto& impl) -> adt::Result { - return adt::errors::NotImplementedError{ - std::string() + - "IndexTupleExprCudaCodeGenerator::CodeGen not support " + - impl->TypeName() + " yet."}; - }); - } - - private: - adt::Result CodeGenImpl(const IndexTupleExprDomain& domain) { - const auto& var_name = NewTmpVarName("_ap_i"); - int i = 0; - auto DoEachPair = [&](const auto& iter, - const auto& stride) -> adt::Result { - if (i++ == 0) { - ADT_CHECK(stride == symbol::DimExpr{int64_t(1)}); - ss() << index_type_name_ << " " << var_name << " = " << iter << ";\n"; - } else { - ADT_LET_CONST_REF(stride_var_name, dim_expr_code_gen_.CodeGen(stride)); - ss() << var_name << " += " << iter << " * " << stride_var_name << ";\n"; - } - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(VisitEachIterAndStride(domain->ranges, DoEachPair)); - return var_name; - } - - template - adt::Result VisitEachIterAndStride( - const adt::List& ranges, const DoEachPairT& DoEachPair) { - symbol::DimExpr stride{int64_t(1)}; - ADT_CHECK(loop_var_names_.size() == ranges->size()); - for (int i = 
loop_var_names_.size() - 1; i >= 0; --i) { - const auto& iter_var_name = loop_var_names_.at(i); - const auto& dim = ranges->at(i); - ADT_RETURN_IF_ERR(DoEachPair(iter_var_name, stride)); - stride = stride * dim; - } - return adt::Ok{}; - } - - std::string NewTmpVarName(const std::string& prefix) { - return ap::common::NewUniqueId(prefix); - } - - std::ostringstream* ss_; - std::vector loop_var_names_; - std::string index_type_name_; - DimExprCudaCodeGenerator dim_expr_code_gen_; -}; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/index_tuple_expr_method_class.h b/paddle/ap/include/index_expr/index_tuple_expr_method_class.h deleted file mode 100644 index e59dc88728bb31..00000000000000 --- a/paddle/ap/include/index_expr/index_tuple_expr_method_class.h +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/axpr/method_class.h" -#include "paddle/ap/include/axpr/naive_class_ops.h" -#include "paddle/ap/include/index_expr/index_expr_builtin_functions.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" - -namespace ap::index_expr { - -template -struct IndexTupleExprMethodClass { - using This = IndexTupleExprMethodClass; - using Self = IndexTupleExpr; - - static adt::Result ToString(const ValueT& self_val, - const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return self.ToString(); - } -}; - -template -struct TypeImplIndexTupleExprMethodClass { - using This = TypeImplIndexTupleExprMethodClass; - using Self = axpr::TypeImpl; - - static adt::Result StaticConstructIndexTupleExprDomain( - const ValueT&, const std::vector& args) { - return This{}.ConstructIndexTupleExprDomain(args); - } - - adt::Result ConstructIndexTupleExprDomain( - const std::vector& args) { - ADT_CHECK(args.size() == 1) << adt::errors::TypeError{ - std::string() + "'IndexTupleExpr.Domain' takes 1 argument but " + - std::to_string(args.size()) + " were given."}; - ADT_LET_CONST_REF(list, args.at(0).template TryGet>()) - << adt::errors::TypeError{std::string() + - "the argument 1 of 'IndexTupleExpr.Domain' " - "should a list of DimExpr."}; - adt::List dim_exprs; - dim_exprs->reserve(list->size()); - for (const auto& arg : *list) { - ADT_LET_CONST_REF(dim_expr, CastToDimExpr(arg)) - << adt::errors::TypeError{std::string() + - "the argument 1 of 'IndexTupleExpr.Domain' " - "should a list of DimExpr."}; - dim_exprs->emplace_back(dim_expr); - } - IndexTupleExpr index_tuple_expr{IndexTupleExprDomain{dim_exprs}}; - axpr::BuiltinClassInstance instance{ - GetIndexTupleExprClass(), index_tuple_expr}; - return instance; - } - - adt::Result CastToDimExpr(const ValueT& val) { - ADT_LET_CONST_REF(dim_expr, TryGetDimExpr(val)); - return dim_expr; - } -}; - -template -axpr::TypeImpl> GetIndexTupleExprClass() { - using TypeImplMethods 
= TypeImplIndexTupleExprMethodClass; - using ImplMethods = IndexTupleExprMethodClass; - static auto cls( - axpr::MakeBuiltinClass("IndexTupleExpr", [&](const auto& Define) { - Define("Domain", &TypeImplMethods::StaticConstructIndexTupleExprDomain); - Define("__str__", &ImplMethods::ToString); - })); - using Self = typename ImplMethods::Self; - return axpr::MakeGlobalNaiveClassOps(cls); -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/op_index_tuple_expr_signature.h b/paddle/ap/include/index_expr/op_index_tuple_expr_signature.h deleted file mode 100644 index 37bc15fa948e85..00000000000000 --- a/paddle/ap/include/index_expr/op_index_tuple_expr_signature.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include "paddle/ap/include/axpr/builtin_class_instance.h" -#include "paddle/ap/include/axpr/type.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" -#include "paddle/ap/include/index_expr/op_signature.h" - -namespace ap::index_expr { - -using InIndexTupleExprSignature = InputSignature; -using OutIndexTupleExprSignature = OutputSignature; -using OpIndexTupleExprSignature = OpSignature; - -} // namespace ap::index_expr - -namespace ap::axpr { - -template -axpr::TypeImpl> -GetInIndexTupleExprSignatureClass(); - -template -axpr::TypeImpl> -GetOutIndexTupleExprSignatureClass(); - -template -axpr::TypeImpl> -GetOpIndexTupleExprSignatureClass(); - -} // namespace ap::axpr diff --git a/paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h b/paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h deleted file mode 100644 index 7459f2ecab9bc1..00000000000000 --- a/paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/axpr/method_class.h" -#include "paddle/ap/include/axpr/naive_class_ops.h" -#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" - -namespace ap::index_expr { - -template -struct InIndexTupleExprSignatureMethodClass { - using Self = index_expr::InIndexTupleExprSignature; - - static adt::Result ToString(const ValueT& self_val, - const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return self->ToString(); - } -}; - -template -axpr::TypeImpl> -GetInIndexTupleExprSignatureClass() { - using ImplMethods = InIndexTupleExprSignatureMethodClass; - static auto cls(axpr::MakeBuiltinClass( - "InIndexTupleExprSignature", - [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); - using Self = typename ImplMethods::Self; - return axpr::MakeGlobalNaiveClassOps(cls); -} - -template -struct OutIndexTupleExprSignatureMethodClass { - using Self = index_expr::OutIndexTupleExprSignature; - - static adt::Result ToString(const ValueT& self_val, - const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return self->ToString(); - } -}; - -template -axpr::TypeImpl> -GetOutIndexTupleExprSignatureClass() { - using ImplMethods = OutIndexTupleExprSignatureMethodClass; - static auto cls(axpr::MakeBuiltinClass( - "OutIndexTupleExprSignature", - [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); - using Self = typename ImplMethods::Self; - return axpr::MakeGlobalNaiveClassOps(cls); -} - -template -struct OpIndexTupleExprSignatureMethodClass { - using Self = index_expr::OpIndexTupleExprSignature; - - static adt::Result ToString(const ValueT& self_val, - const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return self->ToString(); - } -}; - -template -axpr::TypeImpl> -GetOpIndexTupleExprSignatureClass() { - using ImplMethods = OpIndexTupleExprSignatureMethodClass; - static auto cls(axpr::MakeBuiltinClass( 
- "OpIndexTupleExprSignature", [&](const auto& Define) { - Define("__str__", - &OpIndexTupleExprSignatureMethodClass::ToString); - })); - using Self = typename ImplMethods::Self; - return axpr::MakeGlobalNaiveClassOps(cls); -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/op_signature.h b/paddle/ap/include/index_expr/op_signature.h deleted file mode 100644 index 0c73ec7006e82c..00000000000000 --- a/paddle/ap/include/index_expr/op_signature.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include "paddle/ap/include/index_expr/index_expr.h" - -namespace ap::index_expr { - -template -struct InputSignature { - adt::List descriptors; - - std::string ToString() const { - std::ostringstream ss; - int i = 0; - for (const auto& elt : *descriptors) { - if (i++ > 0) { - ss << ", "; - } - ss << elt.ToString(); - } - return std::string() + "InputSignature(" + ss.str() + ")"; - } - - bool operator==(const InputSignature& other) const { - return other.descriptors == this->descriptors; - } -}; - -template -struct OutputSignature { - adt::List descriptors; - - std::string ToString() const { - std::ostringstream ss; - int i = 0; - for (const auto& elt : *descriptors) { - if (i++ > 0) { - ss << ", "; - } - ss << elt.ToString(); - } - return std::string() + "OutputSignature(" + ss.str() + ")"; - } - - bool operator==(const OutputSignature& other) const { - return other.descriptors == this->descriptors; - } -}; - -template -struct OpSignature { - InputSignature in_signature; - OutputSignature out_signature; - - std::string ToString() const { - return std::string() + "OpSignature(" + in_signature.ToString() + ", " + - out_signature.ToString() + ")"; - } - - bool operator==(const OpSignature& other) const { - return other.in_signature == this->in_signature && - other.out_signature == this->out_signature; - } -}; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/slice.h b/paddle/ap/include/index_expr/slice.h deleted file mode 100644 index 4e4af6456efbcc..00000000000000 --- a/paddle/ap/include/index_expr/slice.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include "paddle/ap/include/axpr/adt.h" -#include "paddle/ap/include/axpr/builtin_class_instance.h" -#include "paddle/ap/include/axpr/type.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" - -namespace ap::index_expr { - -struct SliceImpl { - symbol::DimExpr start; - symbol::DimExpr stop; - symbol::DimExpr step; - - bool operator==(const SliceImpl& other) const { - return (other.start == this->start) && (other.stop == this->stop) && - (other.step == this->step); - } - - std::string ToString() const { - return std::string() + "Slice(start=" + symbol::ToString(start) + - ", stop=" + symbol::ToString(stop) + - ", step=" + symbol::ToString(step) + ")"; - } -}; - -ADT_DEFINE_RC(Slice, const SliceImpl); - -template -axpr::TypeImpl> GetSliceClass(); - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/slice_method_class.h b/paddle/ap/include/index_expr/slice_method_class.h deleted file mode 100644 index 5398457afcbba7..00000000000000 --- a/paddle/ap/include/index_expr/slice_method_class.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ap/include/axpr/method_class.h" -#include "paddle/ap/include/axpr/naive_class_ops.h" -#include "paddle/ap/include/index_expr/slice.h" - -namespace ap::index_expr { - -template -struct SliceMethodClass { - using This = SliceMethodClass; - using Self = Slice; - - static adt::Result ToString(const ValueT& self_val, - const std::vector& args) { - ADT_LET_CONST_REF(self, self_val.template CastTo()); - return self->ToString(); - } -}; - -template -axpr::TypeImpl> GetSliceClass() { - using ImplMethods = SliceMethodClass; - static auto cls(axpr::MakeBuiltinClass( - "Slice", - [&](const auto& Define) { Define("__str__", &ImplMethods::ToString); })); - using Self = typename ImplMethods::Self; - return axpr::MakeGlobalNaiveClassOps(cls); -} - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/valid_index_expr_builder.h b/paddle/ap/include/index_expr/valid_index_expr_builder.h deleted file mode 100644 index a2847ec515d9aa..00000000000000 --- a/paddle/ap/include/index_expr/valid_index_expr_builder.h +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/axpr/error.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_expr_util.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" -#include "paddle/ap/include/index_expr/slice.h" - -namespace ap::index_expr { - -using adt::Result; - -class ValidIndexExprBuilder { - public: - ValidIndexExprBuilder() {} - ValidIndexExprBuilder(const ValidIndexExprBuilder&) = delete; - ValidIndexExprBuilder(ValidIndexExprBuilder&&) = delete; - - Result BroadcastMask(const symbol::DimExpr& dim_expr, - const IndexExpr& index_expr) { - return IndexExprBroadcastMask{dim_expr, index_expr}; - } - - Result Slice(const ap::index_expr::Slice& slice, - const symbol::DimExpr& range, - const IndexExpr& index_expr) { - return IndexExprSlice{slice, range, index_expr}; - } - - Result Affine(const ap::index_expr::Slice& slice, - const symbol::DimExpr& range, - const IndexExpr& index_expr) { - return IndexExprAffine{slice, range, index_expr}; - } - - Result DisjointUnion(const IndexExpr& lhs_index_expr, - const IndexExpr& rhs_index_expr) { - const auto& lhs_domain = IndexExprGetDomain(lhs_index_expr); - const auto& rhs_domain = IndexExprGetDomain(rhs_index_expr); - const auto& pattern_match = ::common::Overloaded{ - [](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) { - return lhs == rhs; - }, - [](const auto&, const auto&) { return false; }}; - const bool do_equal = - std::visit(pattern_match, lhs_domain.variant(), 
rhs_domain.variant()); - if (!do_equal) { - return adt::errors::TypeError{ - "domain of `lhs_index_expr' does not equal to domain of " - "`rhs_index_expr'"}; - } - return index_expr::DisjointUnion{lhs_index_expr, rhs_index_expr}; - } - - Result Permute(const adt::List& perms, - const IndexTupleExpr& indexes_expr) { - if (!IsValidPerm(perms)) { - return adt::errors::InvalidArgumentError{"argument `perms` is not valid"}; - } - const auto& rank = IndexTupleExprGetRank(indexes_expr); - if (!rank.Has()) { - return adt::errors::InvalidArgumentError{ - "wrong indexes_expr argument for IndexTupleExprPermute"}; - } - if (rank.Get() != perms->size()) { - return adt::errors::InvalidArgumentError{std::string( - "the rank of perms does not equal to the rank of " - "indexes_expr. rank(perm): " + - std::to_string(perms->size()) + - ", rank(indexes_expr): " + std::to_string(rank.Get()))}; - } - return IndexTupleExprPermute{perms, indexes_expr}; - } - - Result Reshape(const adt::List& shape, - const IndexTupleExpr& indexes_expr) { - if (ContainsNegative(shape)) { - return adt::errors::InvalidArgumentError{ - "dims in argument `shape` have negative integer"}; - } - const auto& opt_ranges = IndexTupleExprGetRanges(indexes_expr); - if (opt_ranges.Has()) { - return adt::errors::InvalidArgumentError{ - "argument `indexes_expr` is not a ranked IndexTupleExpr"}; - } - if (!ProductEqual(shape, opt_ranges.Get>())) { - return adt::errors::InvalidArgumentError{ - "product of argument `shape` does not equal to elements of " - "`indexes_expr`"}; - } - return IndexTupleExprReshape{shape, indexes_expr}; - } - - Result Transform( - const adt::List& transform_index_exprs, - const IndexTupleExpr& indexes_expr) { - const auto& opt_rank = IndexTupleExprGetRank(indexes_expr); - if (!opt_rank.Has()) { - return adt::errors::TypeError{ - "The first argument of IndexTupleExprTransform must be a ranked " - "IndexTupleExpr."}; - } - const auto& opt_ranges = IndexTupleExprGetRanges(indexes_expr); - if 
(!opt_ranges.Has>()) { - return adt::errors::RuntimeError{ - "error occurred where calling IndexTupleExprGetDims"}; - } - const auto& ranges = opt_ranges.Get>(); - if (opt_rank.Get() != transform_index_exprs->size()) { - return adt::errors::TypeError{ - "The rank of first argument must equal to number of lambdas."}; - } - adt::List domains{}; - domains->reserve(transform_index_exprs->size()); - for (const auto& index_expr : *transform_index_exprs) { - const auto& domain = IndexExprGetDomain(index_expr); - if (!domain.Has()) { - return adt::errors::TypeError{ - "one of transform_index_exprs has no demain."}; - } - domains->emplace_back(domain.Get()); - } - if (ranges != domains) { - return adt::errors::TypeError{ - "domain of `transform_index_exprs' does not equal to range of " - "`indexes_expr'."}; - } - return IndexTupleExprTransform{transform_index_exprs, - indexes_expr}; - } - - // outer(inner(x)) == (outer . inner)(x) - Result Compose(const IndexTupleExpr& outer, - const IndexTupleExpr& inner) { - return outer.Match( - [&](const UndefinedIndexTupleExpr& impl) -> Result { - return impl; - }, - [&](const NothingIndexTupleExpr& impl) -> Result { - return impl; - }, - [&](const IntArrayLikeIndexTupleExpr& impl) -> Result { - return impl; - }, - [&](const IndexTupleExprDomain& domain) -> Result { - const auto& ranges = IndexTupleExprGetRanges(inner); - if (ranges.Has()) { - return adt::errors::TypeError{"`inner_indexes_expr' has no range."}; - } - if (ranges.Get>() != domain->ranges) { - return adt::errors::TypeError{ - "the domain of `outer_indexes_expr' does not equal to the " - "range " - "of `inner_indexes_expr'."}; - } - return inner; - }, - [&](const IndexTupleExprPermute& perm) - -> Result { - const auto& composed_inner = Compose(perm->indexes_expr, inner); - if (composed_inner.HasError()) { - return composed_inner.GetError(); - } - return Permute(perm->perms, composed_inner.Get()); - }, - [&](const IndexTupleExprReshape& reshape) - -> Result { - const auto& 
composed_inner = Compose(reshape->indexes_expr, inner); - if (composed_inner.HasError()) { - return composed_inner.GetError(); - } - return Reshape(reshape->shape, composed_inner.Get()); - }, - [&](const IndexTupleExprTransform& transform) - -> Result { - const auto& composed_inner = Compose(transform->indexes_expr, inner); - if (composed_inner.HasError()) { - return composed_inner.GetError(); - } - return Transform(transform->index_exprs, - composed_inner.Get()); - }); - } - - private: - template - bool IsValidPerm(const PermsT& perms) { - std::vector idx2touched(perms->size(), false); - for (int64_t perm : *perms) { - if (perm < 0) { - return false; - } - if (perm >= perms->size()) { - return false; - } - idx2touched[perm] = true; - } - for (bool touched : idx2touched) { - if (!touched) { - return false; - } - } - return true; - } - - template - bool ContainsNegative(const ShapeT& shape) { - for (const auto& dim : *shape) { - if (!dim.template Has()) { - continue; - } - if (dim.template Get() < 0) { - return true; - } - } - return false; - } - - template - symbol::DimExpr Product(const DimExprsT& dim_exprs) { - symbol::DimExpr ret_expr{1}; - for (const auto& dim : *dim_exprs) { - ret_expr = ret_expr * dim; - } - return ret_expr; - } - - bool ProductEqual(const auto& lhs, const auto& rhs) { - return Product(lhs) == Product(rhs); - } -}; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/value.h b/paddle/ap/include/index_expr/value.h deleted file mode 100644 index dbf6061c5e50d2..00000000000000 --- a/paddle/ap/include/index_expr/value.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/ap/include/axpr/builtin_functions.h" -#include "paddle/ap/include/axpr/core_expr.h" -#include "paddle/ap/include/axpr/dim_expr.h" -#include "paddle/ap/include/axpr/value.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" -#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" -#include "paddle/pir/include/core/attribute.h" -#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h" - -namespace ap::index_expr { - -using axpr::Value; - -using Val = Value; - -} // namespace ap::index_expr diff --git a/paddle/ap/include/index_expr/value_method_class.h b/paddle/ap/include/index_expr/value_method_class.h deleted file mode 100644 index dcdf380fc36411..00000000000000 --- a/paddle/ap/include/index_expr/value_method_class.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/axpr/dim_expr_method_class.h" -#include "paddle/ap/include/axpr/value_method_class.h" -#include "paddle/ap/include/index_expr/index_expr_method_class.h" -#include "paddle/ap/include/index_expr/index_tuple_expr_method_class.h" -#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature_method_class.h" -#include "paddle/ap/include/index_expr/slice_method_class.h" diff --git a/paddle/ap/include/paddle/indexed_ir_graph.h b/paddle/ap/include/paddle/indexed_ir_graph.h deleted file mode 100644 index 9b628fdba7e9ae..00000000000000 --- a/paddle/ap/include/paddle/indexed_ir_graph.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/graph/node_arena.h" -#include "paddle/ap/include/paddle/indexed_ir_node.h" - -namespace ap::paddle { - -using IndexedIrNodeArena = graph::NodeArena; -using IndexedIrNodeArenaPtr = std::shared_ptr; - -struct PureElementwiseIndexedIrGraphImpl { - IndexedIrNodeArenaPtr node_arena; - // free values in fusion op block. - std::vector> inputs; - // yield values in fusion op block. - std::vector> yield_op_inputs; - // output values of fusion op. 
- std::vector outputs; - - std::unordered_map> - pir_value2indexed_ir_value; - - adt::Result> GetIndexedIrValue( - pir::Value value) const { - const auto& iter = this->pir_value2indexed_ir_value.find(value); - ADT_CHECK(iter != this->pir_value2indexed_ir_value.end()); - return iter->second; - } - - bool operator==(const PureElementwiseIndexedIrGraphImpl& other) const { - return this == &other; - } -}; - -ADT_DEFINE_RC(PureElementwiseIndexedIrGraph, PureElementwiseIndexedIrGraphImpl); - -using IndexedIrGraphImpl = std::variant; - -struct IndexedIrGraph : public IndexedIrGraphImpl { - using IndexedIrGraphImpl::IndexedIrGraphImpl; - - ADT_DEFINE_VARIANT_METHODS(IndexedIrGraphImpl); -}; - -} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/indexed_ir_graph_util.h b/paddle/ap/include/paddle/indexed_ir_graph_util.h deleted file mode 100644 index 6d9efc91172746..00000000000000 --- a/paddle/ap/include/paddle/indexed_ir_graph_util.h +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/ap/include/index_expr/index_tuple_expr.h" -#include "paddle/ap/include/paddle/indexed_ir_graph.h" -#include "paddle/ap/include/paddle/pir_node.h" -#include "paddle/ap/include/paddle/pir_util.h" -#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" - -namespace ap::paddle { - -adt::Result CreatePureElementwiseIndexedIrGraph( - const PackedIrOp& pir_op, const index_expr::IndexTupleExpr& indexes_expr); - -adt::Result GetPackedIrOpInputsOutputs( - const PackedIrOp& pir_op, - std::vector* inputs, - std::vector* yield_op_inputs, - std::vector* outputs); - -namespace detail { - -struct CreatePureElementwiseIndexedIrGraphHelper { - struct Ctx { - std::unordered_map> value2node; - - bool Has(pir::Value value) const { - return this->value2node.find(value) != this->value2node.end(); - } - - void Insert(pir::Value value, const IndexedIrValue& node) { - this->value2node[value] = node; - } - - adt::Result> Get(pir::Value value) const { - const auto& iter = this->value2node.find(value); - ADT_CHECK(iter != this->value2node.end()); - return iter->second; - } - }; - - adt::Result Create( - const PackedIrOp& pir_op, - const index_expr::IndexTupleExpr& indexes_expr) { - Ctx ctx{}; - ADT_LET_CONST_REF(node_arena_ptr, - CreateNodeArena(&ctx, pir_op, indexes_expr)); - return CreatePureElementwiseIndexedIrGraph(node_arena_ptr, ctx, pir_op); - } - - adt::Result - CreatePureElementwiseIndexedIrGraph(const IndexedIrNodeArenaPtr& node_arena, - const Ctx& ctx, - const PackedIrOp& pir_op) { - std::vector inputs; - std::vector yield_op_inputs; - std::vector outputs; - ADT_RETURN_IF_ERR(GetPackedIrOpInputsOutputs( - pir_op, &inputs, &yield_op_inputs, &outputs)); - ADT_LET_CONST_REF(input_nodes, GetIndexedIrValues(ctx, inputs)); - ADT_LET_CONST_REF(yield_op_input_nodes, - GetIndexedIrValues(ctx, yield_op_inputs)); - return PureElementwiseIndexedIrGraph{ - node_arena, input_nodes, yield_op_input_nodes, outputs, ctx.value2node}; - } - - adt::Result>> 
GetIndexedIrValues( - const Ctx& ctx, const std::vector values) { - std::vector> ret; - ret.reserve(values.size()); - for (const auto& value : values) { - ADT_LET_CONST_REF(ir_value, ctx.Get(value)); - ret.emplace_back(ir_value); - } - return ret; - } - - adt::Result CreateNodeArena( - Ctx* ctx, - const PackedIrOp& pir_op, - const index_expr::IndexTupleExpr& indexes_expr) { - auto node_arena = std::make_shared(); - for (auto& op : *pir_op.fusion_op.block()) { - if (op.isa()) { - continue; - } - const auto& ir_op = InsertOpNode(node_arena, &op); - InsertValueNodes(ctx, node_arena, &op, indexes_expr); - ADT_RETURN_IF_ERR(ConnectOpOperandEdges(ctx, ir_op)); - ADT_RETURN_IF_ERR(ConnectOpResultEdges(ctx, ir_op)); - } - return node_arena; - } - - adt::Result ConnectOpResultEdges( - Ctx* ctx, const IndexedIrOp& ir_op) { - auto* op = ir_op->op; - for (int i = 0; i < op->num_results(); ++i) { - ADT_LET_CONST_REF(ir_value, ctx->Get(op->result(i))); - ADT_RETURN_IF_ERR( - ir_op->node.ConnectTo(ir_value->node, - graph::IndexedTag{}, - graph::UnindexedTag{})); - } - return adt::Ok{}; - } - - adt::Result ConnectOpOperandEdges( - Ctx* ctx, const IndexedIrOp& ir_op) { - auto* op = ir_op->op; - for (int i = 0; i < op->num_operands(); ++i) { - ADT_LET_CONST_REF(ir_value, ctx->Get(op->operand_source(i))); - ADT_RETURN_IF_ERR( - ir_value->node.ConnectTo(ir_op->node, - graph::UnindexedTag{}, - graph::IndexedTag{})); - } - return adt::Ok{}; - } - - void InsertValueNodes(Ctx* ctx, - const IndexedIrNodeArenaPtr& node_arena, - pir::Operation* op, - const index_expr::IndexTupleExpr& indexes_expr) { - VisitInOutValue(op, [&](pir::Value value) { - const auto& ir_node = node_arena->New([&](const auto& node) { - return IndexedIrValue{node, value, indexes_expr}; - }); - const auto& ir_value = - ir_node.template Get>(); - if (!ctx->Has(value)) { - ctx->Insert(value, ir_value); - } - }); - } - - template - void VisitInOutValue(pir::Operation* op, const DoEachT& DoEach) { - for (int i = 0; i < 
op->num_operands(); ++i) { - DoEach(op->operand_source(i)); - } - for (int i = 0; i < op->num_results(); ++i) { - DoEach(op->result(i)); - } - } - - IndexedIrOp InsertOpNode( - const IndexedIrNodeArenaPtr& node_arena, pir::Operation* op) { - const auto& ir_node = node_arena->New([&](const auto& node) { - return IndexedIrOp{node, op}; - }); - return ir_node.template Get>(); - } -}; - -} // namespace detail - -inline adt::Result CreatePureElementwiseIndexedIrGraph( - const PackedIrOp& pir_op, const index_expr::IndexTupleExpr& indexes_expr) { - detail::CreatePureElementwiseIndexedIrGraphHelper helper{}; - ADT_LET_CONST_REF(ir_graph, helper.Create(pir_op, indexes_expr)); - return ir_graph; -} - -namespace detail { - -struct GetPackedIrOpInputsOutputsHelper { - adt::Result GetPackedIrOpInputsOutputs( - const PackedIrOp& pir_op, - std::vector* inputs, - std::vector* yield_op_inputs, - std::vector* outputs) { - *inputs = ap::paddle::GetUsedExternalValue(*pir_op.fusion_op); - outputs->clear(); - outputs->reserve(pir_op.fusion_op->num_results()); - for (int i = 0; i < pir_op.fusion_op->num_results(); ++i) { - outputs->emplace_back(pir_op.fusion_op->result(i)); - } - bool found_yield_op = false; - for (const auto& op : *pir_op.fusion_op.block()) { - yield_op_inputs->clear(); - yield_op_inputs->reserve(op.num_operands()); - if (op.isa()) { - for (int i = 0; i < op.num_operands(); ++i) { - yield_op_inputs->emplace_back(op.operand_source(i)); - } - found_yield_op = true; - } - } - if (found_yield_op) { - return adt::Ok{}; - } else { - return adt::errors::ValueError{ - "No yield op have been found in fusion op block."}; - } - } -}; - -} // namespace detail - -inline adt::Result GetPackedIrOpInputsOutputs( - const PackedIrOp& pir_op, - std::vector* inputs, - std::vector* yield_op_inputs, - std::vector* outputs) { - detail::GetPackedIrOpInputsOutputsHelper helper{}; - return helper.GetPackedIrOpInputsOutputs( - pir_op, inputs, yield_op_inputs, outputs); -} - -} // namespace 
ap::paddle diff --git a/paddle/ap/include/paddle/indexed_ir_node.h b/paddle/ap/include/paddle/indexed_ir_node.h deleted file mode 100644 index 5b33b2a0b52e75..00000000000000 --- a/paddle/ap/include/paddle/indexed_ir_node.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/graph/node.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" -#include "paddle/pir/include/core/operation.h" -#include "paddle/pir/include/core/value.h" - -namespace ap::paddle { - -inline void ConvertToAZaz09_(std::string* str) { - for (int i = 0; i < str->size(); ++i) { - char* ch = &str->at(i); - if (*ch >= 'a' && *ch <= 'z') { - continue; - } - if (*ch >= 'A' && *ch <= 'Z') { - continue; - } - if (*ch >= '0' && *ch <= '9') { - continue; - } - *ch = '_'; - } -} - -inline std::string GetOpUniqueName(const pir::Operation* op) { - std::string op_name = op->name(); - ConvertToAZaz09_(&op_name); - return op_name + "_" + std::to_string(op->id()); -} - -template -struct IndexedIrValueImpl { - graph::Node node; - pir::Value value; - index_expr::IndexTupleExpr indexes_expr; - - std::string GetUniqueNameInsideNodeArena() const { - if (value.defining_op()) { - return GetOpUniqueName(value.defining_op()) + "_out_" + - std::to_string(node.node_id().value()); - } else { - return std::string() + "non_op_out_" 
+ - std::to_string(node.node_id().value()); - } - } - - bool operator==(const IndexedIrValueImpl& other) const { - return this->value == other.value && - this->indexes_expr == other.indexes_expr; - } -}; - -template -ADT_DEFINE_RC(IndexedIrValue, IndexedIrValueImpl); - -template -struct IndexedIrOpImpl { - graph::Node node; - pir::Operation* op; - - std::string GetUniqueNameInsideNodeArena() const { - return GetOpUniqueName(op) + +"_" + std::to_string(node.node_id().value()); - } - - bool operator==(const IndexedIrOpImpl& other) const { - return this->op == other.op; - } -}; - -template -ADT_DEFINE_RC(IndexedIrOp, IndexedIrOpImpl); - -template -using IndexedIrNodeImpl = - std::variant, IndexedIrOp>; - -struct IndexedIrNode : public IndexedIrNodeImpl { - using IndexedIrNodeImpl::IndexedIrNodeImpl; - - ADT_DEFINE_VARIANT_METHODS(IndexedIrNodeImpl); - - const graph::Node& node() const { - return Match([](const auto& impl) -> const graph::Node& { - return impl->node; - }); - } -}; - -} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/op_cuda_code_gen_impl.h b/paddle/ap/include/paddle/op_cuda_code_gen_impl.h deleted file mode 100644 index 10e535dcd20a08..00000000000000 --- a/paddle/ap/include/paddle/op_cuda_code_gen_impl.h +++ /dev/null @@ -1,1018 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include "paddle/ap/include/axpr/anf_expr_util.h" -#include "paddle/ap/include/axpr/data_type_util.h" -#include "paddle/ap/include/axpr/lambda_expr_builder.h" -#include "paddle/ap/include/axpr/pointer_type_util.h" -#include "paddle/ap/include/code_gen/code_gen_ctx.h" -#include "paddle/ap/include/code_gen/dim_expr_kernel_arg_id.h" -#include "paddle/ap/include/code_gen/op_code_gen_ctx.h" -#include "paddle/ap/include/code_gen/op_cuda_gen_impl.h" -#include "paddle/ap/include/drr/node.h" -#include "paddle/ap/include/drr/topo_kind.h" -#include "paddle/ap/include/drr/value.h" -#include "paddle/ap/include/graph/node.h" -#include "paddle/ap/include/index_expr/index_tuple_expr_cuda_code_generator.h" -#include "paddle/ap/include/ir_match/native_or_ref_ir_value.h" -#include "paddle/ap/include/paddle/indexed_ir_graph_util.h" -#include "paddle/ap/include/paddle/pir_graph_descriptor.h" -#include "paddle/ap/include/paddle/pir_node.h" -#include "paddle/ap/include/registry/registry.h" -#include "paddle/ap/include/registry/registry_mgr.h" -#include "paddle/ap/include/registry/registry_singleton.h" -#include "paddle/ap/include/registry/value.h" -#include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/pir/include/core/builtin_type.h" - -namespace ap::paddle { - -struct OpCudaCodeGenImpl { - using BirNode = PirNode; - using OpCodeGenCtx = code_gen::OpCodeGenCtx; - using IrOp = code_gen::IrOp; - - using DrrValue = drr::Value; - using DrrNode = drr::Node; - using DrrGraphNode = graph::Node; - using DrrPackedIrOp = drr::PackedIrOp; - using DrrOptPackedIrOp = drr::OptPackedIrOp; - using DrrOptPackedIrOpOperand = drr::OptPackedIrOpOperand; - using DrrOptPackedIrOpResult = drr::OptPackedIrOpResult; - - using DrrTrivialFusionIrOpImpl = - std::variant; - struct DrrTrivialFusionIrOp : public DrrTrivialFusionIrOpImpl { - using DrrTrivialFusionIrOpImpl::DrrTrivialFusionIrOpImpl; - ADT_DEFINE_VARIANT_METHODS(DrrTrivialFusionIrOpImpl); - - 
DrrGraphNode node() const { - return Match([](const auto& impl) { return impl->node; }); - } - }; - - using DrrNativeIrValue = drr::NativeIrValue; - using DrrPackedIrValue = drr::PackedIrValue; - using IndexTupleExpr = index_expr::IndexTupleExpr; - - using GraphMatchCtx = ir_match::GraphMatchCtx; - - using Registry = registry::Registry; - - using ClassAttrs = axpr::ClassAttrs; - - using Function = axpr::Function; - - adt::Result ConvertFusionOpToClassAttrs( - const OpCodeGenCtx& op_code_gen_ctx, const IrOp& ir_op) { - using RetT = adt::Result; - return ir_op.Match( - [&](const PackedIrOp& packed_ir_op) -> RetT { - return PackedIrOpConvertFusionOpToClassAttrs(op_code_gen_ctx, - packed_ir_op); - }, - [&](const RefIrOp& ref_ir_op) -> RetT { - return RefIrOpConvertFusionOpToClassAttrs(op_code_gen_ctx, ref_ir_op); - }, - [&](const auto&) -> RetT { - return adt::errors::TypeError{ - std::string() + - "only packed ir op get supported in ConvertFusionOpToLambda."}; - }); - } - - adt::Result PackedIrOpConvertFusionOpToClassAttrs( - const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { - ADT_LET_CONST_REF( - index_tuple_expr, - GetPureElementwiseLoopIndexTupleExpr(op_code_gen_ctx, packed_ir_op)); - ADT_LET_CONST_REF( - ir_graph, - CreatePureElementwiseIndexedIrGraph(packed_ir_op, index_tuple_expr)); - ADT_LET_CONST_REF(init_func, - PackedIrOpMakeInitFuncByFusionOp( - op_code_gen_ctx, ir_graph, packed_ir_op)); - ADT_LET_CONST_REF(compute_func, - PackedIrOpMakeComputeFuncByFusionOp( - op_code_gen_ctx, ir_graph, packed_ir_op)); - ADT_LET_CONST_REF(load_from_register_func, - PackedIrOpMakeLoadFromRegisterFuncByFusionOp( - op_code_gen_ctx, ir_graph, packed_ir_op)); - ADT_LET_CONST_REF(store_to_register_func, - PackedIrOpMakeStoreToRegisterFuncByFusionOp( - op_code_gen_ctx, ir_graph, packed_ir_op)); - std::string class_name = "PackedIrOpClass"; - adt::List>> - empty_bases{}; - axpr::AttrMap methods{}; - methods->Set("__init__", init_func); - 
methods->Set("compute", compute_func); - methods->Set("load_from_register", load_from_register_func); - methods->Set("store_to_register", store_to_register_func); - return ClassAttrs{class_name, empty_bases, methods}; - } - - adt::Result GetPureElementwiseLoopIndexTupleExpr( - const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { - ADT_LET_CONST_REF( - shape, GetPureElementwiseLoopDimExpr(op_code_gen_ctx, packed_ir_op)); - return index_expr::IndexTupleExprDomain{shape}; - } - - adt::Result> GetPureElementwiseLoopDimExpr( - const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { - const auto& input_flags = op_code_gen_ctx->input_index_loop_anchor_flags; - { - ADT_LET_CONST_REF( - num_native_ir_inputs, - NumNativeIrInputBirValues(op_code_gen_ctx, packed_ir_op)); - ADT_CHECK(input_flags->size() == num_native_ir_inputs) - << adt::errors::TypeError{ - std::string() + - "len(input_index_loop_anchor_flags) should equal to number of " - "native ir inputs of fusion op. (" + - std::to_string(input_flags->size()) + " v.s. " + - std::to_string(num_native_ir_inputs) + ")"}; - } - const auto& output_flags = op_code_gen_ctx->output_index_loop_anchor_flags; - { - ADT_LET_CONST_REF( - num_native_ir_outputs, - NumNativeIrOutputBirValues(op_code_gen_ctx, packed_ir_op)); - ADT_CHECK(output_flags->size() == num_native_ir_outputs) - << adt::errors::TypeError{ - std::string() + - "len(output_index_loop_anchor_flags) should equal to number " - "of native ir outputs of fusion op. (" + - std::to_string(output_flags->size()) + " v.s. 
" + - std::to_string(num_native_ir_outputs) + ")"}; - } - using Shape = adt::List; - auto GetShape = [&](pir::Value value) -> adt::Result { - ADT_LET_CONST_REF(shape_ptr, NativeIrValue{value}.GetShapeDimExprsPtr()); - Shape shape; - shape->reserve(shape_ptr->size()); - shape->assign(shape_ptr->begin(), shape_ptr->end()); - return shape; - }; - std::optional opt_shape; - auto InitOrCheckShape = [&](pir::Value value) -> adt::Result { - ADT_LET_CONST_REF(shape, GetShape(value)); - if (opt_shape.has_value()) { - ADT_CHECK(opt_shape.value() == shape) << adt::errors::TypeError{ - "All loop anchors should have same shapes."}; - } else { - opt_shape = shape; - } - return adt::Ok{}; - }; - { - int input_idx = 0; - auto DoEachNativeInput = [&](pir::Value value) -> adt::Result { - if (input_flags->at(input_idx++).value()) { - ADT_RETURN_IF_ERR(InitOrCheckShape(value)); - } - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(VisitNativeIrInputBirValue( - op_code_gen_ctx, packed_ir_op, DoEachNativeInput)); - } - { - int output_idx = 0; - auto DoEachNativeOutput = [&](pir::Value value) -> adt::Result { - if (output_flags->at(output_idx++).value()) { - ADT_RETURN_IF_ERR(InitOrCheckShape(value)); - } - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(VisitNativeIrOutputBirValue( - op_code_gen_ctx, packed_ir_op, DoEachNativeOutput)); - } - ADT_CHECK(opt_shape.has_value()) << adt::errors::TypeError{ - "At least one flag should be set in input_index_loop_anchor_flags or " - "output_index_loop_anchor_flags"}; - return opt_shape.value(); - } - - adt::Result NumNativeIrInputBirValues( - const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { - std::size_t num_values = 0; - auto Acc = [&](pir::Value) -> adt::Result { - ++num_values; - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR( - VisitNativeIrInputBirValue(op_code_gen_ctx, packed_ir_op, Acc)); - return num_values; - } - - adt::Result NumNativeIrOutputBirValues( - const OpCodeGenCtx& op_code_gen_ctx, const PackedIrOp& packed_ir_op) { - 
std::size_t num_values = 0; - auto Acc = [&](pir::Value) -> adt::Result { - ++num_values; - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR( - VisitNativeIrOutputBirValue(op_code_gen_ctx, packed_ir_op, Acc)); - return num_values; - } - - adt::Result RefIrOpConvertFusionOpToClassAttrs( - const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { - ADT_LET_CONST_REF( - init_func, RefIrOpMakeInitFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); - ADT_LET_CONST_REF( - compute_func, - RefIrOpMakeComputeFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); - ADT_LET_CONST_REF( - load_from_register_func, - RefIrOpMakeLoadFromRegisterFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); - ADT_LET_CONST_REF( - store_to_register_func, - RefIrOpMakeStoreToRegisterFuncByFusionOp(op_code_gen_ctx, ref_ir_op)); - std::string class_name = "RefIrOpClass"; - adt::List>> - empty_bases{}; - axpr::AttrMap methods{}; - methods->Set("__init__", init_func); - methods->Set("compute", compute_func); - methods->Set("load_from_register", load_from_register_func); - methods->Set("store_to_register", load_from_register_func); - return ClassAttrs{class_name, empty_bases, methods}; - } - - adt::Result PackedIrOpMakeInitFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, - const IndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - return ir_graph.Match([&](const auto& impl) -> adt::Result { - return PackedIrOpMakeInitFuncByFusionOpImpl( - op_code_gen_ctx, impl, packed_ir_op); - }); - } - - adt::Result PackedIrOpMakeStoreToRegisterFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, - const IndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - return ir_graph.Match([&](const auto& impl) -> adt::Result { - return PackedIrOpMakeStoreToRegisterFuncByFusionOpImpl( - op_code_gen_ctx, impl, packed_ir_op); - }); - } - - adt::Result PackedIrOpMakeLoadFromRegisterFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, - const IndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - return 
ir_graph.Match([&](const auto& impl) -> adt::Result { - return PackedIrOpMakeLoadFromRegisterFuncByFusionOpImpl( - op_code_gen_ctx, impl, packed_ir_op); - }); - } - - adt::Result PackedIrOpMakeLoadFromRegisterFuncByFusionOpImpl( - const OpCodeGenCtx& op_code_gen_ctx, - const PureElementwiseIndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - axpr::LambdaExprBuilder lmbd; - auto GetMapFunc = [&](auto& ctx) -> axpr::AnfExpr { - auto& value_class_var = - ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); - auto& name_var = ctx.Var("indexed_ir_node_info_tuple").At(0); - auto& index_tuple_expr_var = ctx.Var("indexed_ir_node_info_tuple").At(1); - auto& dtype_var = ctx.Var("indexed_ir_node_info_tuple").At(2); - auto& input_var = value_class_var.Call( - index_tuple_expr_var, dtype_var, ctx.Var("input_local_var_name")); - return ctx.Var(axpr::kBuiltinList()).Call(name_var, input_var); - }; - using AnfExprs = std::vector; - auto GetAllInputIndexedIrNodeInfo = - [&](auto* ctx) -> adt::Result { - AnfExprs ret; - auto DoEachNativeIrValue = - [&](pir::Value ir_value) -> adt::Result { - AnfExprs indexed_ir_info_tuple; - ADT_LET_CONST_REF(dtype, ConvertToDataType(ir_value)); - for (const auto& input : ir_graph->inputs) { - if (input->value == ir_value) { - auto& info_var = - ctx->Var(axpr::kBuiltinList()) - .Call(ctx->String(input->GetUniqueNameInsideNodeArena()), - ctx->Var("self").Attr("loop_index_tuple_expr"), - ctx->Var("DataType").Attr(dtype.Name())); - indexed_ir_info_tuple.emplace_back( - static_cast(info_var)); - } - } - auto& indexed_ir_info_var = - ctx->Var(axpr::kBuiltinList()).Apply(indexed_ir_info_tuple); - ret.emplace_back(static_cast(indexed_ir_info_var)); - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(VisitNativeIrInputBirValue( - op_code_gen_ctx, packed_ir_op, DoEachNativeIrValue)); - return ret; - }; - auto GetBody = [&](auto& ctx) -> adt::Result { - const auto& map_func_var_name = ctx.NewTmpVarName(); - ctx.Var(map_func_var_name) = - 
lmbd.Lambda({"indexed_ir_node_info_tuple"}, GetMapFunc); - ADT_LET_CONST_REF(indexed_nodes, GetAllInputIndexedIrNodeInfo(&ctx)); - auto& indexed_nodes_var = - ctx.Var(axpr::kBuiltinList()).Apply(indexed_nodes); - auto& native_input_indexed_nodes_var = - indexed_nodes_var.At(ctx.Var("native_input_index")); - auto& items_var = ctx.Var("map").Call(ctx.Var(map_func_var_name), - native_input_indexed_nodes_var); - auto& ret = ctx.Var("OrderedDict").Call(items_var); - return static_cast>(ret); - }; - ADT_LET_CONST_REF(anf_expr, - lmbd.TryLambda({"self", - "code_gen_ctx", - "input_local_var_name", - "native_input_index"}, - GetBody)); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - adt::Result PackedIrOpMakeStoreToRegisterFuncByFusionOpImpl( - const OpCodeGenCtx& op_code_gen_ctx, - const PureElementwiseIndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - ADT_CHECK(ir_graph->yield_op_inputs.size() == ir_graph->outputs.size()); - auto GetOutputIndex = - [&](pir::Value output) -> adt::Result> { - for (int i = 0; i < ir_graph->outputs.size(); ++i) { - if (output == ir_graph->outputs.at(i)) { - return i; - } - } - return std::nullopt; - }; - axpr::LambdaExprBuilder lmbd; - using AnfExprs = std::vector; - auto GetAllOutputIndexedIrNodeInfo = - [&](auto* ctx) -> adt::Result { - AnfExprs ret; - auto DoEachNativeIrValue = - [&](pir::Value ir_value) -> adt::Result { - ADT_LET_CONST_REF(dtype, ConvertToDataType(ir_value)); - ADT_LET_CONST_REF(opt_idx, GetOutputIndex(ir_value)); - ADT_CHECK(opt_idx.has_value()); - const auto& output = ir_graph->yield_op_inputs.at(opt_idx.value()); - auto& indexed_ir_info_tuple = - ctx->Var(axpr::kBuiltinList()) - .Call(ctx->String(output->GetUniqueNameInsideNodeArena()), - ctx->Var("self").Attr("loop_index_tuple_expr"), - 
ctx->Var("DataType").Attr(dtype.Name())); - ret.emplace_back(indexed_ir_info_tuple); - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(VisitNativeIrOutputBirValue( - op_code_gen_ctx, packed_ir_op, DoEachNativeIrValue)); - return ret; - }; - - auto GetBody = [&](auto& ctx) -> adt::Result { - ADT_LET_CONST_REF(indexed_nodes, GetAllOutputIndexedIrNodeInfo(&ctx)); - auto& indexed_nodes_var = - ctx.Var(axpr::kBuiltinList()).Apply(indexed_nodes); - auto& native_output_indexed_node_var = - indexed_nodes_var.At(ctx.Var("native_output_index")); - auto& name_var = native_output_indexed_node_var.At(0); - auto& output_var = ctx.Var("compute_results").At(name_var); - auto& value_class_var = - ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); - auto& index_tuple_expr_var = native_output_indexed_node_var.At(1); - auto& dtype_var = native_output_indexed_node_var.At(2); - auto& store_var = value_class_var.Call( - index_tuple_expr_var, dtype_var, ctx.Var("out_value_local_var_name")); - ctx.Var("code_gen_ctx").Attr("assign").Call(store_var, output_var); - return ctx.None(); - }; - ADT_LET_CONST_REF(anf_expr, - lmbd.TryLambda({"self", - "code_gen_ctx", - "compute_results", - "out_value_local_var_name", - "native_output_index"}, - GetBody)); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - adt::Result PackedIrOpMakeComputeFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, - const IndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - return ir_graph.Match([&](const auto& impl) -> adt::Result { - return PackedIrOpMakeComputeFuncByFusionOpImpl( - op_code_gen_ctx, impl, packed_ir_op); - }); - } - - adt::Result PackedIrOpMakeComputeFuncByFusionOpImpl( - const OpCodeGenCtx& op_code_gen_ctx, - const PureElementwiseIndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - 
axpr::LambdaExprBuilder lmbd; - using Ok = adt::Result; - auto UnpackInputs = [&](auto* ctx) -> Ok { - for (const auto& input : ir_graph->inputs) { - const auto& name = input->GetUniqueNameInsideNodeArena(); - ctx->Var(name) = ctx->Var("inputs").At(ctx->String(name)); - } - return adt::Ok{}; - }; - auto ComputeNativeOpCodeGen = [&](auto* ctx, - const auto& indexed_ir_op) -> Ok { - ADT_LET_CONST_REF(input_var_names, GetInputVarNames(indexed_ir_op)); - const auto& indexed_ir_op_name = - indexed_ir_op->GetUniqueNameInsideNodeArena(); - ADT_LET_CONST_REF(output_var_names, GetOutputVarNames(indexed_ir_op)); - std::vector args{ctx->Var("code_gen_ctx")}; - args.reserve(input_var_names.size() + 1); - for (const auto& input_var_name : input_var_names) { - args.push_back(ctx->Var(input_var_name)); - } - auto& outputs_var = ctx->Var("self").Attr(indexed_ir_op_name).Apply(args); - for (int i = 0; i < output_var_names.size(); ++i) { - const auto& output_var_name = output_var_names.at(i); - ctx->Var(output_var_name) = outputs_var.At(i); - } - return adt::Ok{}; - }; - auto PackedOutputs = [&](auto* ctx) -> adt::Result { - std::vector yield_op_input_items; - yield_op_input_items.reserve(ir_graph->yield_op_inputs.size()); - for (const auto& yield_op_input : ir_graph->yield_op_inputs) { - const auto& name = yield_op_input->GetUniqueNameInsideNodeArena(); - const auto& pair = ctx->Var(axpr::kBuiltinList()) - .Call(ctx->String(name), ctx->Var(name)); - yield_op_input_items.emplace_back(static_cast(pair)); - } - const auto& items = - ctx->Var(axpr::kBuiltinList()).Call(yield_op_input_items); - return ctx->Call("OrderedDict", items); - }; - auto GetBody = [&](auto& ctx) -> adt::Result { - auto* ctx_ptr = &ctx; - ADT_RETURN_IF_ERR(UnpackInputs(ctx_ptr)); - ADT_RETURN_IF_ERR( - VisitIndexedIrOp(ir_graph, [&](const auto& indexed_ir_op) -> Ok { - return ComputeNativeOpCodeGen(ctx_ptr, indexed_ir_op); - })); - ADT_LET_CONST_REF(packed_outputs, PackedOutputs(ctx_ptr)); - return 
packed_outputs; - }; - std::vector arg_names{"self", "code_gen_ctx", "inputs"}; - ADT_LET_CONST_REF(anf_expr, lmbd.TryLambda(arg_names, GetBody)); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - adt::Result> GetInputVarNames( - const IndexedIrOp& indexed_ir_op) const { - ADT_LET_CONST_REF(upstreams, indexed_ir_op->node.UpstreamNodes()); - std::vector ret{}; - ret.reserve(upstreams.size()); - auto DoEach = [&](const auto& node) -> adt::Result { - ADT_LET_CONST_REF(ir_node, node.Get()); - ADT_LET_CONST_REF( - ir_value, ir_node.template TryGet>()); - ret.push_back(ir_value->GetUniqueNameInsideNodeArena()); - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(upstreams.VisitNodes(DoEach)); - return ret; - } - - adt::Result> GetOutputVarNames( - const IndexedIrOp& indexed_ir_op) { - ADT_LET_CONST_REF(downstreams, indexed_ir_op->node.DownstreamNodes()); - std::vector ret{}; - ret.reserve(downstreams.size()); - auto DoEach = [&](const auto& node) -> adt::Result { - ADT_LET_CONST_REF(ir_node, node.Get()); - ADT_LET_CONST_REF( - ir_value, ir_node.template TryGet>()); - ret.push_back(ir_value->GetUniqueNameInsideNodeArena()); - return adt::Ok{}; - }; - ADT_RETURN_IF_ERR(downstreams.VisitNodes(DoEach)); - return ret; - } - - adt::Result PackedIrOpMakeInitFuncByFusionOpImpl( - const OpCodeGenCtx& op_code_gen_ctx, - const PureElementwiseIndexedIrGraph& ir_graph, - const PackedIrOp& packed_ir_op) { - axpr::LambdaExprBuilder lmbd; - using Ok = adt::Result; - auto ConstructNativeOpCodeGen = [&](auto* ctx, - const auto& indexed_ir_op) -> Ok { - const auto& op_name = indexed_ir_op->op->name(); - auto& class_var = ctx->Var("get_native_op_code_generator_class") - .Call(ctx->String(op_name)); - { - std::vector input_dtype_anf_exprs; - for (int i = 0; i < indexed_ir_op->op->num_operands(); ++i) { - 
ADT_LET_CONST_REF( - dtype, ConvertToDataType(indexed_ir_op->op->operand_source(i))); - const auto& dtype_var = ctx->Var("DataType").Attr(dtype.Name()); - input_dtype_anf_exprs.emplace_back( - static_cast(dtype_var)); - } - ctx->Var("input_dtypes") = - ctx->Call(axpr::kBuiltinList(), input_dtype_anf_exprs); - } - { - std::vector output_dtype_anf_exprs; - for (int i = 0; i < indexed_ir_op->op->num_results(); ++i) { - ADT_LET_CONST_REF(dtype, - ConvertToDataType(indexed_ir_op->op->result(i))); - const auto& dtype_var = ctx->Var("DataType").Attr(dtype.Name()); - output_dtype_anf_exprs.emplace_back( - static_cast(dtype_var)); - } - ctx->Var("output_dtypes") = - ctx->Call(axpr::kBuiltinList(), output_dtype_anf_exprs); - } - { - std::vector input_index_tuple_exprs; - input_index_tuple_exprs.reserve(indexed_ir_op->op->num_operands()); - for (int i = 0; i < indexed_ir_op->op->num_operands(); ++i) { - input_index_tuple_exprs.emplace_back( - ctx->Var("loop_index_tuple_expr")); - } - ctx->Var("input_index_tuple_exprs") = - ctx->Call(axpr::kBuiltinList(), input_index_tuple_exprs); - } - { - std::vector output_index_tuple_exprs; - output_index_tuple_exprs.reserve(indexed_ir_op->op->num_results()); - for (int i = 0; i < indexed_ir_op->op->num_results(); ++i) { - output_index_tuple_exprs.emplace_back( - ctx->Var("loop_index_tuple_expr")); - } - ctx->Var("output_index_tuple_exprs") = - ctx->Call(axpr::kBuiltinList(), output_index_tuple_exprs); - } - const auto& indexed_op_name = - indexed_ir_op->GetUniqueNameInsideNodeArena(); - axpr::AnfExpr indexed_op = - class_var.Call(ctx->Var("index_expr_code_gen"), - ctx->String(indexed_op_name), - ctx->Var("input_dtypes"), - ctx->Var("output_dtypes"), - ctx->Var("input_index_tuple_exprs"), - ctx->Var("output_index_tuple_exprs"), - /*attrs*/ ctx->None()); - ctx->Var("self").SetAttr(indexed_op_name, indexed_op); - return adt::Ok{}; - }; - auto GetBody = [&](auto& ctx) -> adt::Result { - ctx.Var("self").SetAttr("class_factory", 
ctx.Var("class_factory")); - ctx.Var("self").SetAttr("loop_index_tuple_expr", - ctx.Var("loop_index_tuple_expr")); - ctx.Var("index_expr_code_generator_class") = - ctx.Var("class_factory") - .Attr("get_index_expr_code_generator_class") - .Call(); - ctx.Var("index_expr_code_gen") = - ctx.Var("index_expr_code_generator_class") - .Call(ctx.Var("loop_var_names")); - ctx.Var("get_native_op_code_generator_class") = - ctx.Var("class_factory") - .Attr("get_native_op_code_generator_class") - .Call(); - auto* ctx_ptr = &ctx; - ADT_RETURN_IF_ERR( - VisitIndexedIrOp(ir_graph, [&](const auto& indexed_ir_op) -> Ok { - return ConstructNativeOpCodeGen(ctx_ptr, indexed_ir_op); - })); - return ctx.None(); - }; - ADT_LET_CONST_REF(anf_expr, - lmbd.TryLambda({"self", - "class_factory", - "loop_index_tuple_expr", - "loop_var_names"}, - GetBody)); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - template - adt::Result VisitIndexedIrOp( - const PureElementwiseIndexedIrGraph& ir_graph, - const DoEachIndexIrNodeT& DoEachIndexIrNode) { - for (const auto& node : ir_graph->node_arena->nodes()) { - if (node.template Has>()) { - ADT_RETURN_IF_ERR( - DoEachIndexIrNode(node.template Get>())); - } - } - return adt::Ok{}; - } - - adt::Result RefIrOpMakeInitFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { - axpr::LambdaExprBuilder lmbd; - auto GetBody = [](auto& ctx) { - ctx.Var("self").SetAttr("class_factory", ctx.Var("class_factory")); - ctx.Var("self").SetAttr("loop_index_tuple_expr", - ctx.Var("loop_index_tuple_expr")); - return ctx.None(); - }; - const auto& anf_expr = lmbd.Lambda( - {"self", "class_factory", "loop_index_tuple_expr", "loop_var_names"}, - GetBody); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, 
core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - adt::Result RefIrOpMakeComputeFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { - axpr::LambdaExprBuilder lmbd; - auto GetBody = [](auto& ctx) -> axpr::AnfExpr { return ctx.Var("inputs"); }; - const auto& anf_expr = - lmbd.Lambda({"self", "code_gen_ctx", "inputs"}, GetBody); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - adt::Result RefIrOpMakeLoadFromRegisterFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { - pir::Value value = ref_ir_op.ref_node_info->ir_value.value; - ADT_LET_CONST_REF(dtype, ConvertToDataType(value)); - axpr::LambdaExprBuilder lmbd; - auto GetBody = [&](auto& ctx) { - auto& value_class_var = - ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); - auto& index_tuple_expr_var = - ctx.Var("self").Attr("loop_index_tuple_expr"); - auto& dtype_var = ctx.Var("DataType").Attr(dtype.Name()); - auto& input_var = value_class_var.Call( - index_tuple_expr_var, dtype_var, ctx.Var("input_local_var_name")); - return ctx.Var("OrderedDict") - .Call(ctx.Var(axpr::kBuiltinList()) - .Call(ctx.Var(axpr::kBuiltinList()) - .Call(ctx.String("sole_ir_value"), input_var))); - }; - const auto& anf_expr = lmbd.Lambda( - {"self", "code_gen_ctx", "input_local_var_name", "native_input_index"}, - GetBody); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - adt::Result RefIrOpMakeStoreToRegisterFuncByFusionOp( - const OpCodeGenCtx& op_code_gen_ctx, const RefIrOp& ref_ir_op) { - pir::Value 
value = ref_ir_op.ref_node_info->ir_value.value; - ADT_LET_CONST_REF(dtype, ConvertToDataType(value)); - axpr::LambdaExprBuilder lmbd; - auto GetBody = [&](auto& ctx) { - auto& value_class_var = - ctx.Var("self").Attr("class_factory").Attr("get_value_class").Call(); - auto& index_tuple_expr_var = - ctx.Var("self").Attr("loop_index_tuple_expr"); - auto& dtype_var = ctx.Var("DataType").Attr(dtype.Name()); - auto& output_var = value_class_var.Call( - index_tuple_expr_var, dtype_var, ctx.Var("out_value_local_var_name")); - ctx.Var("code_gen_ctx") - .Attr("assign") - .Call(output_var, - ctx.Var("compute_results").At(ctx.String("sole_ir_value"))); - return ctx.None(); - }; - const auto& anf_expr = lmbd.Lambda({"self", - "code_gen_ctx", - "compute_results", - "out_value_local_var_name", - "native_output_index"}, - GetBody); - const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); - ADT_LET_CONST_REF( - atomic, core_expr.template TryGet>()); - ADT_LET_CONST_REF(lambda, - atomic.template TryGet>()); - return Function{lambda, std::nullopt}; - } - - using NativeOrRefIrValue = ir_match::NativeOrRefIrValue; - - template - adt::Result VisitNativeIrInputBirValue( - const OpCodeGenCtx& op_code_gen_ctx, - const PackedIrOp& packed_ir_op, - const DoEachT& DoEach) { - ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); - ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, - GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); - auto DoEachNativeValue = - [&](const auto& drr_ir_value) -> adt::Result { - ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); - return DoEach(value); - }; - auto DoEachPackedValue = - [&](const auto& drr_ir_value) -> adt::Result { - // Do nothing. 
- return adt::Ok{}; - }; - return VisitDrrTrivialFusionIrOpInput( - drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); - } - - template - adt::Result VisitInputBirNativeIrValue( - const OpCodeGenCtx& op_code_gen_ctx, - const PackedIrOp& packed_ir_op, - const DoEachT& DoEach) { - ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); - ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, - GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); - auto DoEachNativeValue = - [&](const auto& drr_ir_value) -> adt::Result { - ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); - return DoEach(value); - }; - auto DoEachPackedValue = - [&](const auto& drr_ir_value) -> adt::Result { - ADT_RETURN_IF_ERR( - VisitPackedPirValue(graph_match_ctx, drr_ir_value, DoEach)); - return adt::Ok{}; - }; - return VisitDrrTrivialFusionIrOpInput( - drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); - } - - template - adt::Result VisitPackedPirValue(const GraphMatchCtx& match_ctx, - const DrrPackedIrValue& drr_ir_value, - const DoEachT& DoEach) { - auto DoEachPirNode = [&](const PirNode& pir_node) -> adt::Result { - ADT_LET_CONST_REF(pir_value, pir_node.template TryGet()); - ADT_RETURN_IF_ERR(DoEach(pir_value.value)); - return adt::Ok{}; - }; - const auto& node = drr_ir_value->node; - ADT_RETURN_IF_ERR( - match_ctx->VisitPackedBigGraphIrValueNode(node, DoEachPirNode)); - return adt::Ok{}; - } - - template - adt::Result VisitNativeIrOutputBirValue( - const OpCodeGenCtx& op_code_gen_ctx, - const PackedIrOp& packed_ir_op, - const DoEachT& DoEach) { - ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); - ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, - GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); - auto DoEachNativeValue = - [&](const auto& drr_ir_value) -> adt::Result { - ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); - return DoEach(value); - }; - auto DoEachPackedValue = - [&](const auto& 
drr_ir_value) -> adt::Result { - // Do nothing. - return adt::Ok{}; - }; - return VisitDrrTrivialFusionIrOpOutput( - drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); - } - - template - adt::Result VisitOutputNativeIrValue( - const OpCodeGenCtx& op_code_gen_ctx, - const PackedIrOp& packed_ir_op, - const DoEachT& DoEach) { - ADT_LET_CONST_REF(graph_match_ctx, GetGraphMatchCtx(op_code_gen_ctx)); - ADT_LET_CONST_REF(drr_trivial_fusion_ir_op, - GetDrrTrivialFusionIrOp(graph_match_ctx, packed_ir_op)); - auto DoEachNativeValue = - [&](const auto& drr_ir_value) -> adt::Result { - ADT_LET_CONST_REF(value, GetPirValue(graph_match_ctx, drr_ir_value)); - return DoEach(value); - }; - auto DoEachPackedValue = - [&](const auto& drr_ir_value) -> adt::Result { - ADT_RETURN_IF_ERR( - VisitPackedPirValue(graph_match_ctx, drr_ir_value, DoEach)); - return adt::Ok{}; - }; - return VisitDrrTrivialFusionIrOpOutput( - drr_trivial_fusion_ir_op, DoEachNativeValue, DoEachPackedValue); - } - - template - adt::Result VisitDrrTrivialFusionIrOpInput( - const DrrTrivialFusionIrOp& drr_trivial_fusion_ir_op, - const DoEachNativeValueT& DoEachNativeValue, - const DoEachPackedValueT DoEachPackedValue) { - LOG(ERROR) << "drr_trivial_fusion_ir_op: " - << graph::NodeDescriptor{}.DebugId( - drr_trivial_fusion_ir_op.node()); - auto DoEach = [&](const DrrGraphNode& node) -> adt::Result { - ADT_LET_CONST_REF(drr_node, node.Get()); - LOG(ERROR) << "drr_trivial_fusion_ir_op input: " - << graph::NodeDescriptor{}.DebugId(node); - return drr_node.Match( - [&](const DrrNativeIrValue& ir_value) -> adt::Result { - return DoEachNativeValue(ir_value); - }, - [&](const DrrPackedIrValue& ir_value) -> adt::Result { - return DoEachPackedValue(ir_value); - }, - [&](const auto&) -> adt::Result { - return adt::errors::ValueError{ - "the second connected upstreams of drr packed ir op should be " - "drr native ir values or drr packed ir values."}; - }); - }; - return 
VisitSecondConnectedUpstream(drr_trivial_fusion_ir_op.node(), - DoEach); - } - - template - adt::Result VisitDrrTrivialFusionIrOpOutput( - const DrrTrivialFusionIrOp& drr_trivial_fusion_ir_op, - const DoEachNativeValueT& DoEachNativeValue, - const DoEachPackedValueT DoEachPackedValue) { - auto DoEach = [&](const DrrGraphNode& node) -> adt::Result { - ADT_LET_CONST_REF(drr_node, node.Get()); - return drr_node.Match( - [&](const DrrNativeIrValue& ir_value) -> adt::Result { - return DoEachNativeValue(ir_value); - }, - [&](const DrrPackedIrValue& ir_value) -> adt::Result { - return DoEachPackedValue(ir_value); - }, - [&](const auto&) -> adt::Result { - return adt::errors::ValueError{ - "the second connected upstreams of drr packed ir op should be " - "drr native ir values or drr packed ir values."}; - }); - }; - return VisitSecondConnectedDownstream(drr_trivial_fusion_ir_op.node(), - DoEach); - } - - template - adt::Result VisitSecondConnectedUpstream(const DrrGraphNode& node, - const DoEachT& DoEach) { - auto DoEachUpstream = [&](const auto& upstream) -> adt::Result { - return VisitUpstream(upstream, DoEach); - }; - return VisitUpstream(node, DoEachUpstream); - } - - template - adt::Result VisitSecondConnectedDownstream(const DrrGraphNode& node, - const DoEachT& DoEach) { - auto DoEachUpstream = [&](const auto& downstream) -> adt::Result { - return VisitDownstream(downstream, DoEach); - }; - return VisitDownstream(node, DoEachUpstream); - } - - template - adt::Result VisitUpstream(const DrrGraphNode& node, - const DoEachT& DoEach) { - ADT_LET_CONST_REF(upstreams, node.UpstreamNodes()); - return upstreams.VisitNodes(DoEach); - } - - template - adt::Result VisitDownstream(const DrrGraphNode& node, - const DoEachT& DoEach) { - ADT_LET_CONST_REF(downstreams, node.DownstreamNodes()); - return downstreams.VisitNodes(DoEach); - } - - adt::Result GetPirValue( - const GraphMatchCtx& graph_match_ctx, - const DrrNativeIrValue& drr_native_ir_value) { - const auto& node = 
drr_native_ir_value->node; - ADT_LET_CONST_REF(pir_node, graph_match_ctx->GetSoleBigGraphNode(node)); - ADT_LET_CONST_REF(pir_value, pir_node.template TryGet()); - return pir_value.value; - } - - adt::Result GetDrrTrivialFusionIrOp( - const GraphMatchCtx& graph_match_ctx, const PackedIrOp& packed_ir_op) { - ADT_LET_CONST_REF(node, - graph_match_ctx->GetMatchedSmallGraphNode(packed_ir_op)); - ADT_LET_CONST_REF(drr_node, node.Get()); - using RetT = adt::Result; - return drr_node.Match( - [&](const DrrPackedIrOp& impl) -> RetT { return impl; }, - [&](const DrrOptPackedIrOp& impl) -> RetT { return impl; }, - [&](const auto&) -> RetT { - return adt::errors::NotImplementedError{ - "conversion from DrrNode to DrrTrivialFusionIrOp failed."}; - }); - } - - adt::Result GetGraphMatchCtx( - const OpCodeGenCtx& op_code_gen_ctx) const { - ADT_LET_CONST_REF(code_gen_ctx, - adt::WeakPtrLock(op_code_gen_ctx->code_gen_ctx)); - ADT_CHECK(code_gen_ctx->ir_match_ctx.has_value()); - const auto& ir_match_ctx = code_gen_ctx->ir_match_ctx.value(); - return ir_match_ctx->graph_match_ctx; - } - - adt::Result GetConstDataPointerType(pir::Value value) { - ADT_LET_CONST_REF(data_type, ConvertToDataType(value)); - return axpr::GetConstPointerTypeFromDataType(data_type); - } - - adt::Result GetMutableDataPointerType(pir::Value value) { - ADT_LET_CONST_REF(data_type, ConvertToDataType(value)); - return axpr::GetMutablePointerTypeFromDataType(data_type); - } - - adt::Result ConvertToDataType(pir::Value value) { - ADT_LET_CONST_REF(dtype, ConvertToPhiDataType(value)); - return ap::axpr::GetDataTypeFromPhiDataType(dtype); - } - - adt::Result ConvertToPhiDataType(pir::Value value) { - ADT_LET_CONST_REF(type, GetPirDataType(value)); - try { - return ::paddle::dialect::TransToPhiDataType(type); - } catch (const std::exception& e) { - return adt::errors::TypeError{ - "failed to cast from pir data type to phi data type."}; - } - } - - adt::Result GetPirDataType(pir::Value value) { - if 
(!value.type().isa()) { - return adt::errors::NotImplementedError{ - "pir value must be of DenseTensorType"}; - } - const auto dense_tensor_type = - value.type().dyn_cast(); - return dense_tensor_type.dtype(); - } -}; - -} // namespace ap::paddle - -namespace ap::code_gen { - -template <> -struct OpCudaCodeGenImpl - : public paddle::OpCudaCodeGenImpl {}; - -} // namespace ap::code_gen diff --git a/paddle/ap/src/index_expr/index_closure.cc b/paddle/ap/src/index_expr/index_closure.cc deleted file mode 100644 index 84b2aba34b6aa5..00000000000000 --- a/paddle/ap/src/index_expr/index_closure.cc +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/ap/include/index_expr/index_closure.h" -#include "paddle/ap/include/index_expr/op_index_tuple_expr_signature.h" -#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" - -namespace ap::index_expr { - -adt::Result OrderedOneofIndexClosureImpl::operator()( - const IndexTupleExpr& indexes_expr) const { - size_t count = 0; - for (const auto& [_, lambdas] : nice2index_lambdas) { - for (const auto& lambda : lambdas) { - const auto& res = CallLambda(lambda, indexes_expr); - if (res.Has()) { - return res.Get(); - } - ++count; - } - } - return adt::errors::ValueError{ - std::string() + - "all index closure failed. 
tried count: " + std::to_string(count)}; -} - -adt::Result OrderedOneofIndexClosureImpl::CallLambda( - const Lambda& lambda, const IndexTupleExpr& indexes_expr) const { - axpr::BuiltinClassInstance instance{GetIndexTupleExprClass(), - indexes_expr}; - const std::vector args{closure_data.ctx, - closure_data.inputs_meta, - closure_data.outputs_meta, - closure_data.in_vars, - Val{instance}}; - const auto& opt_ret = (*this->interpreter)(lambda, args); - ADT_RETURN_IF_ERR(opt_ret); - const auto& ret = opt_ret.GetOkValue(); - return ret.template CastTo(); -} - -namespace { - -template -adt::Result OpIndexesTransformApply( - const OpIndexesTransformSignature& indexes_transform_signature, - const IndexesTransformApplyT& IndexesTransformApply) { - InIndexTupleExprSignature in_sig; - for (const auto& transform : - *indexes_transform_signature.in_signature.descriptors) { - const auto& converted = IndexesTransformApply(transform); - ADT_RETURN_IF_ERR(converted); - in_sig.descriptors->emplace_back(converted.GetOkValue()); - } - OutIndexTupleExprSignature out_sig; - for (const auto& transform : - *indexes_transform_signature.out_signature.descriptors) { - const auto& converted = IndexesTransformApply(transform); - ADT_RETURN_IF_ERR(converted); - out_sig.descriptors->emplace_back(converted.GetOkValue()); - } - return OpIndexTupleExprSignature{in_sig, out_sig}; -} - -} // namespace - -adt::Result RecordableIndexClosureImpl::operator()( - const IndexTupleExpr& indexes_expr) const { - const auto& ApplyTransform = [&](const TrackedIndexesTransform& transform) { - return transform.Match( - [&](const adt::IdentityFunc&) -> adt::Result { - return indexes_expr; - }, - [&](const IndexTupleExpr& tacked_indexes_expr_as_func) - -> adt::Result { - return ValidIndexExprBuilder().Compose(tacked_indexes_expr_as_func, - indexes_expr); - }); - }; - return OpIndexesTransformApply(this->op_indexes_transform_signature, - ApplyTransform); -} - -adt::Result IndexClosure::operator()( - const 
IndexTupleExpr& indexes_expr) const { - return Match([&](const auto& impl) { return (*impl)(indexes_expr); }); -} - -} // namespace ap::index_expr diff --git a/paddle/ap/src/index_expr/index_expr_builtin_functions.cc b/paddle/ap/src/index_expr/index_expr_builtin_functions.cc deleted file mode 100644 index eef5d8227608dc..00000000000000 --- a/paddle/ap/src/index_expr/index_expr_builtin_functions.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/ap/include/axpr/builtin_functions.h" -#include "paddle/ap/include/index_expr/index_expr_util.h" -#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" -#include "paddle/ap/include/index_expr/value.h" -#include "paddle/ap/include/index_expr/value_method_class.h" - -namespace ap::index_expr {} // namespace ap::index_expr diff --git a/paddle/ap/src/index_expr/index_expr_util.cc b/paddle/ap/src/index_expr/index_expr_util.cc deleted file mode 100644 index a8a788ba9d6a00..00000000000000 --- a/paddle/ap/src/index_expr/index_expr_util.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/ap/include/index_expr/index_expr_util.h" -#include "paddle/ap/include/axpr/adt.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" diff --git a/paddle/ap/src/index_expr/valid_index_expr_builder.cc b/paddle/ap/src/index_expr/valid_index_expr_builder.cc deleted file mode 100644 index 6d2f5188ca61cd..00000000000000 --- a/paddle/ap/src/index_expr/valid_index_expr_builder.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" -#include "paddle/ap/include/adt/adt.h" -#include "paddle/ap/include/index_expr/index_expr.h" -#include "paddle/ap/include/index_expr/index_expr_util.h" -#include "paddle/ap/include/index_expr/index_tuple_expr.h" -#include "paddle/ap/include/index_expr/slice.h" - -namespace ap::index_expr {} // namespace ap::index_expr diff --git a/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc b/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc index fa2b370313d4b4..9d0c68ab83f052 100644 --- a/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc +++ b/paddle/ap/src/paddle/pass/ap_kernel_define_helper.cc @@ -19,7 +19,8 @@ #include "paddle/ap/include/code_gen/value_method_class.h" #include "paddle/ap/include/drr/drr_graph_descriptor.h" #include "paddle/ap/include/drr/drr_node_descriptor.h" -#include "paddle/ap/include/paddle/op_cuda_code_gen_impl.h" +#include "paddle/ap/include/paddle/pir_graph_descriptor.h" +#include "paddle/ap/include/paddle/pir_node_descriptor.h" #include "paddle/ap/include/paddle/pir_node_method_class.h" namespace cinn::dialect::ir { diff --git a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc index df1e5cd717465c..f6b956f54c0593 100644 --- a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc @@ -37,12 +37,10 @@ #include "paddle/ap/include/drr/result_pattern_helper.h" #include "paddle/ap/include/drr/value.h" #include "paddle/ap/include/graph/graph_helper.h" -#include "paddle/ap/include/index_expr/valid_index_expr_builder.h" #include "paddle/ap/include/ir_match/graph_matcher.h" #include "paddle/ap/include/ir_match/ir_match_ctx.h" #include "paddle/ap/include/ir_match/op_match_ctx_method_class.h" #include "paddle/ap/include/ir_match/tensor_match_ctx_method_class.h" -#include "paddle/ap/include/paddle/indexed_ir_graph_util.h" #include 
"paddle/ap/include/paddle/pass/ap_drr_helper.h" #include "paddle/ap/include/paddle/pass/ap_kernel_define_helper.h" #include "paddle/ap/include/paddle/pass/ap_registry_helper.h" From f15a7660f5deb46b69e79c56931146ff3769437d Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 17 Feb 2025 19:23:23 +0800 Subject: [PATCH 03/43] Fix compiling error on CI. --- paddle/ap/include/axpr/anf_expr_util.h | 4 ++-- paddle/ap/include/axpr/attr_map.h | 1 + paddle/ap/include/axpr/bool_method_class.h | 1 - paddle/ap/include/axpr/builtin_class_instance.h | 5 +---- paddle/ap/include/axpr/builtin_symbol.h | 4 ++-- paddle/ap/include/axpr/cps_interpreter.h | 6 +++--- paddle/ap/include/axpr/data_value_method_class.h | 1 - paddle/ap/include/axpr/float_method_class.h | 1 - paddle/ap/include/axpr/mutable_list_method_class.h | 4 ++-- paddle/ap/src/axpr/builtin_functions.cc | 13 +++---------- .../ap/src/code_gen/code_gen_result_method_class.cc | 2 -- .../ap/src/code_module/code_module_method_class.cc | 2 -- paddle/ap/src/code_module/directory_method_class.cc | 2 -- .../ap/src/code_module/file_content_method_class.cc | 2 -- .../ap/src/code_module/func_declare_method_class.cc | 2 -- paddle/ap/src/code_module/package_method_class.cc | 2 -- paddle/ap/src/code_module/project_method_class.cc | 2 -- paddle/ap/src/code_module/soft_link_method_class.cc | 2 -- paddle/ap/src/drr/drr_ctx_method_class.cc | 2 -- .../ap/src/drr/native_ir_op_declare_method_class.cc | 2 -- paddle/ap/src/drr/native_ir_op_method_class.cc | 2 -- paddle/ap/src/drr/native_ir_value_method_class.cc | 2 -- .../drr/opt_packed_ir_op_declare_method_class.cc | 2 -- paddle/ap/src/drr/opt_packed_ir_op_method_class.cc | 2 -- .../ap/src/drr/packed_ir_op_declare_method_class.cc | 2 -- paddle/ap/src/drr/packed_ir_op_method_class.cc | 2 -- paddle/ap/src/drr/packed_ir_value_method_class.cc | 2 -- .../src/drr/res_ptn_op_pattern_ctx_method_class.cc | 2 -- .../drr/res_ptn_tensor_pattern_ctx_method_class.cc | 2 -- 
.../res_ptn_unbound_native_ir_op_method_class.cc | 2 -- .../res_ptn_unbound_packed_ir_op_method_class.cc | 2 -- .../ap/src/drr/result_pattern_ctx_method_class.cc | 2 -- .../ap/src/drr/source_pattern_ctx_method_class.cc | 2 -- .../src/drr/src_ptn_op_pattern_ctx_method_class.cc | 2 -- .../drr/src_ptn_tensor_pattern_ctx_method_class.cc | 2 -- .../src_ptn_unbound_native_ir_op_method_class.cc | 2 -- .../src_ptn_unbound_packed_ir_op_method_class.cc | 2 -- paddle/ap/src/drr/unbound_ir_value_method_class.cc | 2 -- .../drr/unbound_opt_packed_ir_op_method_class.cc | 2 -- .../src/drr/unbound_packed_ir_value_method_class.cc | 2 -- .../src/index_expr/index_expr_builtin_functions.cc | 1 - .../src/kernel_dispatch/device_ctx_method_class.cc | 2 -- .../ap/src/paddle/pass/ap_lower_fusion_op_pass.cc | 2 -- paddle/ap/src/paddle/pass/ap_registry_helper.cc | 2 -- paddle/ap/src/paddle/pass/ir_helper_method_class.cc | 2 -- paddle/ap/src/paddle/pir/attribute_method_class.cc | 2 -- .../ap/src/paddle/pir/pass_manager_method_class.cc | 2 -- paddle/ap/src/paddle/pir/pass_method_class.cc | 2 -- paddle/ap/src/paddle/pir/pir_method_class.cc | 2 -- paddle/ap/src/paddle/pir/program_method_class.cc | 2 -- .../ap/src/paddle/pir/shape_or_data_method_class.cc | 2 -- paddle/ap/src/paddle/pir/type_method_class.cc | 2 -- 52 files changed, 14 insertions(+), 109 deletions(-) diff --git a/paddle/ap/include/axpr/anf_expr_util.h b/paddle/ap/include/axpr/anf_expr_util.h index 5193d9af2c31c9..c9f3bec3b30f85 100644 --- a/paddle/ap/include/axpr/anf_expr_util.h +++ b/paddle/ap/include/axpr/anf_expr_util.h @@ -571,7 +571,7 @@ struct ParseJsonToAnfExprHelperCallAnfExpr { "should a valid atomic AnfExpr."); } std::vector> args; - for (int i = 1; i < j_obj.size(); ++i) { + for (size_t i = 1; i < j_obj.size(); ++i) { const auto& arg = j_obj.at(i); const auto& arg_expr = ConvertJsonToAnfExpr(arg); if (!arg_expr.HasOkValue()) { @@ -658,7 +658,7 @@ struct ParseJsonToAnfExprHelperLetAnfExpr { } std::vector> bindings; const 
auto& j_bindings = j_obj.at(1); - for (int i = 0; i < j_bindings.size(); ++i) { + for (size_t i = 0; i < j_bindings.size(); ++i) { const auto& binding = j_bindings.at(i); if (!binding.is_array()) { return JsonParseFailed(binding, diff --git a/paddle/ap/include/axpr/attr_map.h b/paddle/ap/include/axpr/attr_map.h index 9eaa103d73b008..38d4bcfed72340 100644 --- a/paddle/ap/include/axpr/attr_map.h +++ b/paddle/ap/include/axpr/attr_map.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/ap/include/axpr/adt.h" #include "paddle/ap/include/axpr/error.h" diff --git a/paddle/ap/include/axpr/bool_method_class.h b/paddle/ap/include/axpr/bool_method_class.h index c581d2fd980d62..928fff6781de42 100644 --- a/paddle/ap/include/axpr/bool_method_class.h +++ b/paddle/ap/include/axpr/bool_method_class.h @@ -77,7 +77,6 @@ struct BoolMethodClass { rhs); }, [&](const auto& impl) -> adt::Result { - using T = std::decay_t; return adt::errors::TypeError{ std::string() + "unsupported operand type(s) for " + ArithmeticOp::Name() + ": 'bool' and '" + diff --git a/paddle/ap/include/axpr/builtin_class_instance.h b/paddle/ap/include/axpr/builtin_class_instance.h index 596546066062af..2c8970fc8cb152 100644 --- a/paddle/ap/include/axpr/builtin_class_instance.h +++ b/paddle/ap/include/axpr/builtin_class_instance.h @@ -100,10 +100,7 @@ template ClassAttrs MakeBuiltinClass(const std::string& class_name, const VisitorT& Visitor) { AttrMap attr_map; - Visitor([&](const auto& name, const auto& val) { - using TestType = decltype(BuiltinFrameValImpl{val}); - attr_map->Set(name, val); - }); + Visitor([&](const auto& name, const auto& val) { attr_map->Set(name, val); }); adt::List>> empty_superclasses{}; return ClassAttrs{class_name, empty_superclasses, attr_map}; } diff --git a/paddle/ap/include/axpr/builtin_symbol.h b/paddle/ap/include/axpr/builtin_symbol.h index 25774d7947c87e..3b0c3353ed4626 100644 --- a/paddle/ap/include/axpr/builtin_symbol.h +++ 
b/paddle/ap/include/axpr/builtin_symbol.h @@ -140,7 +140,7 @@ struct Length : public std::monostate { PEXPR_FOR_EACH_UNARY_OP(DEFINE_UNARY_SYMBOL); -#undef DEFINE_UNARY_SYMBOL; +#undef DEFINE_UNARY_SYMBOL #define DEFINE_BINARY_SYMBOL(name, op) \ struct name : public std::monostate { \ @@ -152,7 +152,7 @@ PEXPR_FOR_EACH_UNARY_OP(DEFINE_UNARY_SYMBOL); PEXPR_FOR_EACH_BINARY_OP(DEFINE_BINARY_SYMBOL); -#undef DEFINE_BINARY_SYMBOL; +#undef DEFINE_BINARY_SYMBOL #define AXPR_FOR_EACH_SYMBOL_OP(_) \ PEXPR_FOR_EACH_BINARY_OP(_) \ diff --git a/paddle/ap/include/axpr/cps_interpreter.h b/paddle/ap/include/axpr/cps_interpreter.h index a458a766471cb3..52af21bc696dc1 100644 --- a/paddle/ap/include/axpr/cps_interpreter.h +++ b/paddle/ap/include/axpr/cps_interpreter.h @@ -426,7 +426,7 @@ class CpsInterpreter : public InterpreterBase { passed_args.insert(self_name); ADT_RETURN_IF_ERR(env->Set(self_name, self.value())); } - for (int pos_arg_idx = 0; pos_arg_idx < pos_args->size(); + for (size_t pos_arg_idx = 0; pos_arg_idx < pos_args->size(); ++pos_arg_idx, ++lambda_arg_idx) { const auto& arg_name = lambda->args.at(lambda_arg_idx).value(); passed_args.insert(arg_name); @@ -479,13 +479,13 @@ class CpsInterpreter : public InterpreterBase { ss << "() missing " << (lambda->args.size() - args.size()) << " required positional arguments: "; ss << "'" << lambda->args.at(args.size()).value() << "'"; - for (int i = args.size() + 1; i < lambda->args.size(); ++i) { + for (size_t i = args.size() + 1; i < lambda->args.size(); ++i) { ss << "and '" << lambda->args.at(i).value() << "'"; } return adt::errors::TypeError{ss.str()}; } } - for (int i = 0; i < args.size(); ++i) { + for (size_t i = 0; i < args.size(); ++i) { const auto& arg_name = lambda->args.at(i).value(); ADT_RETURN_IF_ERR(env->Set(arg_name, args.at(i))); } diff --git a/paddle/ap/include/axpr/data_value_method_class.h b/paddle/ap/include/axpr/data_value_method_class.h index 95aae9771526d6..e6c0d93a34fe18 100644 --- 
a/paddle/ap/include/axpr/data_value_method_class.h +++ b/paddle/ap/include/axpr/data_value_method_class.h @@ -150,7 +150,6 @@ adt::Result ConstructDataValue(const ValueT&, [](int64_t c) -> adt::Result { return DataValue{c}; }, [](const DataValue& c) -> adt::Result { return c; }, [&](const auto& impl) -> adt::Result { - using T = std::decay_t; return adt::errors::TypeError{ std::string() + "unsupported operand type for constructor of 'DataValue': '" + diff --git a/paddle/ap/include/axpr/float_method_class.h b/paddle/ap/include/axpr/float_method_class.h index fddb0045c271e4..cfcef07b6ceb84 100644 --- a/paddle/ap/include/axpr/float_method_class.h +++ b/paddle/ap/include/axpr/float_method_class.h @@ -75,7 +75,6 @@ struct FloatMethodClass { rhs); }, [&](const auto& impl) -> adt::Result { - using T = std::decay_t; return adt::errors::TypeError{std::string() + "unsupported operand type(s) for " + ArithmeticOp::Name() + ": 'int' and '" + diff --git a/paddle/ap/include/axpr/mutable_list_method_class.h b/paddle/ap/include/axpr/mutable_list_method_class.h index 2c0cac32b16d8c..5c18d3e236cacd 100644 --- a/paddle/ap/include/axpr/mutable_list_method_class.h +++ b/paddle/ap/include/axpr/mutable_list_method_class.h @@ -86,7 +86,7 @@ struct MethodClassImpl> { if (index < 0) { index += vec->size(); } - if (index >= 0 && index < vec->size()) { + if (index >= 0 && index < static_cast(vec->size())) { return vec->at(index); } return adt::errors::IndexError{"list index out of range"}; @@ -115,7 +115,7 @@ struct MethodClassImpl> { if (idx < 0) { idx += self_ptr->size(); } - ADT_CHECK(idx < self_ptr->size()) + ADT_CHECK(idx < static_cast(self_ptr->size())) << adt::errors::IndexError{"list index out of range"}; self_ptr->at(idx) = args.at(1); return adt::Nothing{}; diff --git a/paddle/ap/src/axpr/builtin_functions.cc b/paddle/ap/src/axpr/builtin_functions.cc index 70e629d91ce6c8..81db4841753eae 100644 --- a/paddle/ap/src/axpr/builtin_functions.cc +++ 
b/paddle/ap/src/axpr/builtin_functions.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once #include "paddle/ap/include/axpr/builtin_functions.h" #include #include @@ -155,18 +154,12 @@ adt::Result ReplaceOrTrimLeftComma( if (start == std::string::npos) { return false; } - if (start < 0) { - return false; - } if (start >= self.size()) { return false; } if (end == std::string::npos) { return false; } - if (end < 0) { - return false; - } if (end >= self.size()) { return false; } @@ -176,7 +169,7 @@ adt::Result ReplaceOrTrimLeftComma( if (self[start] != ',') { return false; } - for (int i = start + 1; i < end; ++i) { + for (size_t i = start + 1; i < end; ++i) { char ch = self[i]; if (ch == ' ') { continue; @@ -386,7 +379,7 @@ Result Zip(const axpr::Value&, } adt::List ret; ret->reserve(size.value()); - for (int i = 0; i < size.value(); ++i) { + for (size_t i = 0; i < size.value(); ++i) { adt::List tuple; tuple->reserve(args.size()); for (const auto& arg : args) { @@ -423,7 +416,7 @@ Result Reduce(axpr::InterpreterBase* interpreter, ADT_CHECK(start.has_value()); axpr::Value ret{init.value()}; const auto& f = args.at(0); - for (int i = start.value(); i < lst_size; ++i) { + for (size_t i = start.value(); i < lst_size; ++i) { ADT_LET_CONST_REF(elt, lst.at(i)); ADT_LET_CONST_REF( cur_reduced, diff --git a/paddle/ap/src/code_gen/code_gen_result_method_class.cc b/paddle/ap/src/code_gen/code_gen_result_method_class.cc index 3f5a2f67512053..a1b6939d7ce93c 100644 --- a/paddle/ap/src/code_gen/code_gen_result_method_class.cc +++ b/paddle/ap/src/code_gen/code_gen_result_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/code_gen/code_gen_result_method_class.h" namespace ap::code_gen { diff --git a/paddle/ap/src/code_module/code_module_method_class.cc b/paddle/ap/src/code_module/code_module_method_class.cc index dad863ca372389..154f27155678a0 100644 --- a/paddle/ap/src/code_module/code_module_method_class.cc +++ b/paddle/ap/src/code_module/code_module_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/code_module/code_module_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/code_module/directory_method_class.cc b/paddle/ap/src/code_module/directory_method_class.cc index 9e87b343fac6c8..b9d990f1842be8 100644 --- a/paddle/ap/src/code_module/directory_method_class.cc +++ b/paddle/ap/src/code_module/directory_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/code_module/directory_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/code_module/file_content_method_class.cc b/paddle/ap/src/code_module/file_content_method_class.cc index 9639650f74b743..09de528cfc6d71 100644 --- a/paddle/ap/src/code_module/file_content_method_class.cc +++ b/paddle/ap/src/code_module/file_content_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/code_module/file_content_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/code_module/func_declare_method_class.cc b/paddle/ap/src/code_module/func_declare_method_class.cc index c8d0fb82dfca14..384583a9c70962 100644 --- a/paddle/ap/src/code_module/func_declare_method_class.cc +++ b/paddle/ap/src/code_module/func_declare_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/code_module/func_declare_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/code_module/package_method_class.cc b/paddle/ap/src/code_module/package_method_class.cc index 9ffc7d0d41bed7..8e809a48e23241 100644 --- a/paddle/ap/src/code_module/package_method_class.cc +++ b/paddle/ap/src/code_module/package_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/code_module/package_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/code_module/project_method_class.cc b/paddle/ap/src/code_module/project_method_class.cc index 20e0775b7fe935..cb94a7b0928850 100644 --- a/paddle/ap/src/code_module/project_method_class.cc +++ b/paddle/ap/src/code_module/project_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/code_module/project_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/code_module/soft_link_method_class.cc b/paddle/ap/src/code_module/soft_link_method_class.cc index a470df17754859..4742a2ee318a56 100644 --- a/paddle/ap/src/code_module/soft_link_method_class.cc +++ b/paddle/ap/src/code_module/soft_link_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/code_module/soft_link_method_class.h" namespace ap::code_module { diff --git a/paddle/ap/src/drr/drr_ctx_method_class.cc b/paddle/ap/src/drr/drr_ctx_method_class.cc index 9c2591fcb3733b..a9aaf8a671304e 100644 --- a/paddle/ap/src/drr/drr_ctx_method_class.cc +++ b/paddle/ap/src/drr/drr_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/drr_ctx_method_class.h" #include "paddle/ap/include/axpr/callable_helper.h" diff --git a/paddle/ap/src/drr/native_ir_op_declare_method_class.cc b/paddle/ap/src/drr/native_ir_op_declare_method_class.cc index 3db0b27efa1845..5b1818cfe3dc14 100644 --- a/paddle/ap/src/drr/native_ir_op_declare_method_class.cc +++ b/paddle/ap/src/drr/native_ir_op_declare_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/native_ir_op_declare_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/native_ir_op_method_class.cc b/paddle/ap/src/drr/native_ir_op_method_class.cc index b24d4839c7104b..cbee15b51921c5 100644 --- a/paddle/ap/src/drr/native_ir_op_method_class.cc +++ b/paddle/ap/src/drr/native_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/native_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/native_ir_value_method_class.cc b/paddle/ap/src/drr/native_ir_value_method_class.cc index 215183c8283c1f..8e25fcd1c9db3d 100644 --- a/paddle/ap/src/drr/native_ir_value_method_class.cc +++ b/paddle/ap/src/drr/native_ir_value_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/native_ir_value_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc b/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc index 364d7542a88ac0..41eb1d80223ed2 100644 --- a/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc +++ b/paddle/ap/src/drr/opt_packed_ir_op_declare_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/opt_packed_ir_op_declare_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc b/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc index e5b2bb8a76a9cc..d76fbc91570346 100644 --- a/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc +++ b/paddle/ap/src/drr/opt_packed_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/opt_packed_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc b/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc index 325fc9550989c7..e6cb7858762d8e 100644 --- a/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc +++ b/paddle/ap/src/drr/packed_ir_op_declare_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/packed_ir_op_declare_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/packed_ir_op_method_class.cc b/paddle/ap/src/drr/packed_ir_op_method_class.cc index 3dbb966d4f110f..0dead004d35066 100644 --- a/paddle/ap/src/drr/packed_ir_op_method_class.cc +++ b/paddle/ap/src/drr/packed_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/packed_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/packed_ir_value_method_class.cc b/paddle/ap/src/drr/packed_ir_value_method_class.cc index c61fdec00380ea..ba720f45425441 100644 --- a/paddle/ap/src/drr/packed_ir_value_method_class.cc +++ b/paddle/ap/src/drr/packed_ir_value_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/packed_ir_value_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc b/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc index 466d89e3c76144..d2e29ddfd0707a 100644 --- a/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/res_ptn_op_pattern_ctx_method_class.h" #include "paddle/ap/include/axpr/callable_helper.h" diff --git a/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc b/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc index 4eaeb8f9acd07d..ac5e5887eb91bd 100644 --- a/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_tensor_pattern_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/res_ptn_tensor_pattern_ctx_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc b/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc index e87015d403dca4..93f009c4f3045d 100644 --- a/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/res_ptn_unbound_native_ir_op_method_class.h" #include "paddle/ap/include/axpr/callable_helper.h" #include "paddle/ap/include/drr/drr_pass_type_helper.h" diff --git a/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc b/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc index 10954f7702e7ad..0e808ff3bd2568 100644 --- a/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/res_ptn_unbound_packed_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/result_pattern_ctx_method_class.cc b/paddle/ap/src/drr/result_pattern_ctx_method_class.cc index ad2c368fa0102e..189be3c11ad31b 100644 --- a/paddle/ap/src/drr/result_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/result_pattern_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/result_pattern_ctx_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/source_pattern_ctx_method_class.cc b/paddle/ap/src/drr/source_pattern_ctx_method_class.cc index 010d603a0de84b..9ba1b8e0bb9a60 100644 --- a/paddle/ap/src/drr/source_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/source_pattern_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/source_pattern_ctx_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc b/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc index 5abca213f461b8..8f1a414b55759c 100644 --- a/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/src_ptn_op_pattern_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/src_ptn_op_pattern_ctx_method_class.h" #include #include "paddle/ap/include/drr/drr_pass_type_helper.h" diff --git a/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc b/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc index 3837c9eb37942e..f1d89ff894145f 100644 --- a/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/src_ptn_tensor_pattern_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/src_ptn_tensor_pattern_ctx_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc b/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc index 87c62331f10601..e05d83b8a24e71 100644 --- a/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc +++ b/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/src_ptn_unbound_native_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc b/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc index a9382aa62b6c7c..e5729c2085c396 100644 --- a/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc +++ b/paddle/ap/src/drr/src_ptn_unbound_packed_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/src_ptn_unbound_packed_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/unbound_ir_value_method_class.cc b/paddle/ap/src/drr/unbound_ir_value_method_class.cc index a439916ed2bd7b..7cd49db170f311 100644 --- a/paddle/ap/src/drr/unbound_ir_value_method_class.cc +++ b/paddle/ap/src/drr/unbound_ir_value_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/unbound_ir_value_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc b/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc index 4bde3c896544d0..3d80135f55b95d 100644 --- a/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc +++ b/paddle/ap/src/drr/unbound_opt_packed_ir_op_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/drr/unbound_opt_packed_ir_op_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc b/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc index 288f51586b423e..d0f958989984b2 100644 --- a/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc +++ b/paddle/ap/src/drr/unbound_packed_ir_value_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/drr/unbound_packed_ir_value_method_class.h" namespace ap::drr { diff --git a/paddle/ap/src/index_expr/index_expr_builtin_functions.cc b/paddle/ap/src/index_expr/index_expr_builtin_functions.cc index eef5d8227608dc..e8852feff11d14 100644 --- a/paddle/ap/src/index_expr/index_expr_builtin_functions.cc +++ b/paddle/ap/src/index_expr/index_expr_builtin_functions.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once #include "paddle/ap/include/axpr/builtin_functions.h" #include "paddle/ap/include/index_expr/index_expr_util.h" #include "paddle/ap/include/index_expr/valid_index_expr_builder.h" diff --git a/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc b/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc index 0f39d6cd3e015a..7c3d98ba470d78 100644 --- a/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc +++ b/paddle/ap/src/kernel_dispatch/device_ctx_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/kernel_dispatch/device_ctx_method_class.h" #include "paddle/ap/include/axpr/value.h" diff --git a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc index df1e5cd717465c..bda279b016859c 100644 --- a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" #include "paddle/ap/include/memory/circlable_ref_list_base.h" diff --git a/paddle/ap/src/paddle/pass/ap_registry_helper.cc b/paddle/ap/src/paddle/pass/ap_registry_helper.cc index fccbf3cdf5f336..97c28316454c45 100644 --- a/paddle/ap/src/paddle/pass/ap_registry_helper.cc +++ b/paddle/ap/src/paddle/pass/ap_registry_helper.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pass/ap_registry_helper.h" #include "paddle/ap/include/registry/registry_mgr.h" diff --git a/paddle/ap/src/paddle/pass/ir_helper_method_class.cc b/paddle/ap/src/paddle/pass/ir_helper_method_class.cc index c23d5100e9d275..f27aca13262a80 100644 --- a/paddle/ap/src/paddle/pass/ir_helper_method_class.cc +++ b/paddle/ap/src/paddle/pass/ir_helper_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/paddle/pass/ir_helper_method_class.h" #include "paddle/ap/include/axpr/module_mgr.h" #include "paddle/ap/include/axpr/to_string.h" diff --git a/paddle/ap/src/paddle/pir/attribute_method_class.cc b/paddle/ap/src/paddle/pir/attribute_method_class.cc index 4478a6964203cf..c38e0b4e040f1d 100644 --- a/paddle/ap/src/paddle/pir/attribute_method_class.cc +++ b/paddle/ap/src/paddle/pir/attribute_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pir/attribute_method_class.h" #include "paddle/ap/include/axpr/abstract_list.h" #include "paddle/ap/include/axpr/callable_helper.h" diff --git a/paddle/ap/src/paddle/pir/pass_manager_method_class.cc b/paddle/ap/src/paddle/pir/pass_manager_method_class.cc index f915b218da6cd5..d89fd3589987fa 100644 --- a/paddle/ap/src/paddle/pir/pass_manager_method_class.cc +++ b/paddle/ap/src/paddle/pir/pass_manager_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pir/pass_manager_method_class.h" namespace ap::paddle { diff --git a/paddle/ap/src/paddle/pir/pass_method_class.cc b/paddle/ap/src/paddle/pir/pass_method_class.cc index a3718da95e6f0c..e8a279de8ab8b6 100644 --- a/paddle/ap/src/paddle/pir/pass_method_class.cc +++ b/paddle/ap/src/paddle/pir/pass_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/paddle/pir/pass_method_class.h" namespace ap::paddle { diff --git a/paddle/ap/src/paddle/pir/pir_method_class.cc b/paddle/ap/src/paddle/pir/pir_method_class.cc index f1d1b653c61f6c..0c479d7f609359 100644 --- a/paddle/ap/src/paddle/pir/pir_method_class.cc +++ b/paddle/ap/src/paddle/pir/pir_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pir/pir_method_class.h" #include "paddle/ap/include/axpr/module_mgr.h" #include "paddle/ap/include/paddle/pir_node.h" diff --git a/paddle/ap/src/paddle/pir/program_method_class.cc b/paddle/ap/src/paddle/pir/program_method_class.cc index 4c5276bbb5a5d9..0c89e01b09eabf 100644 --- a/paddle/ap/src/paddle/pir/program_method_class.cc +++ b/paddle/ap/src/paddle/pir/program_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pir/program_method_class.h" #include "paddle/ap/include/axpr/dim_expr_method_class.h" #include "paddle/ap/include/paddle/pir/attribute_method_class.h" diff --git a/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc b/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc index 10b75ba7e8396a..d4debbff0af3f8 100644 --- a/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc +++ b/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#pragma once - #include "paddle/ap/include/axpr/callable_helper.h" #include "paddle/ap/include/axpr/data_type_util.h" #include "paddle/ap/include/paddle/pir/type_adt_type_id.h" diff --git a/paddle/ap/src/paddle/pir/type_method_class.cc b/paddle/ap/src/paddle/pir/type_method_class.cc index 166a2df9ddb77b..00abf1bb756cf0 100644 --- a/paddle/ap/src/paddle/pir/type_method_class.cc +++ b/paddle/ap/src/paddle/pir/type_method_class.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/ap/include/paddle/pir/type_method_class.h" #include "paddle/ap/include/axpr/callable_helper.h" #include "paddle/ap/include/axpr/data_type_util.h" From 35196b4081721254a5b211e483065dbd334ee98c Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 19 Feb 2025 12:11:07 +0800 Subject: [PATCH 04/43] Change the log level. --- paddle/ap/include/ir_match/topo_match_ctx.h | 8 ++++---- .../ap/src/paddle/pass/ap_lower_fusion_op_pass.cc | 15 ++++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/paddle/ap/include/ir_match/topo_match_ctx.h b/paddle/ap/include/ir_match/topo_match_ctx.h index 8a1c617e7f2ae3..dbdbc7aa144082 100644 --- a/paddle/ap/include/ir_match/topo_match_ctx.h +++ b/paddle/ap/include/ir_match/topo_match_ctx.h @@ -85,7 +85,7 @@ struct TopoMatchCtxImpl { adt::Result InitBigGraphNodes(const sg_node_t& sg_node, const std::list& val) { - VLOG(0) << "InitBigGraphNodes. sg_node: " + VLOG(4) << "InitBigGraphNodes. 
sg_node: " << graph::NodeDescriptor{}.DebugId(sg_node) << ", val:" << [&] { @@ -125,7 +125,7 @@ struct TopoMatchCtxImpl { "'val'"}; } auto* ptr = &this->sg_node2bg_nodes_[sg_node]; - VLOG(0) << "UpdateBigGraphNodes: sg_node: " + VLOG(4) << "UpdateBigGraphNodes: sg_node: " << graph::NodeDescriptor{}.DebugId(sg_node) << ", old_val:" << [&] { @@ -135,7 +135,7 @@ struct TopoMatchCtxImpl { } return ss.str(); }(); - VLOG(0) << "UpdateBigGraphNodes: sg_node: " + VLOG(4) << "UpdateBigGraphNodes: sg_node: " << graph::NodeDescriptor{}.DebugId(sg_node) << ", arg_val:" << [&] { @@ -152,7 +152,7 @@ struct TopoMatchCtxImpl { lhs_iter = ptr->erase(lhs_iter); } } - VLOG(0) << "UpdateBigGraphNodes: sg_node: " + VLOG(4) << "UpdateBigGraphNodes: sg_node: " << graph::NodeDescriptor{}.DebugId(sg_node) << ", new_val: " << [&] { diff --git a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc index 3b65ba5349f56d..1c558379bd39e4 100644 --- a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc @@ -305,7 +305,7 @@ struct ApRewriter { pir::Operation* op, pir::PatternRewriter* rewriter) const { ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); - LOG(ERROR) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; + VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; return RewriteByResultPattern(match_ctx, op, rewriter); } @@ -2296,7 +2296,7 @@ class NativeOpAnchorApLowerFusionOpPattern : public pir::RewritePattern { return false; } ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); - LOG(ERROR) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; + VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; ADT_LET_CONST_REF( success, ap_rewriter_.Rewrite(opt_match_ctx.value(), op, rewriter)); if (success) { @@ -2410,10 +2410,11 @@ class DefaultAnchorApLowerFusionOpPattern : public pir::RewritePattern { } const auto& ret = 
this->TryMatchAndRewrite(op, &rewriter); if (ret.HasError()) { - LOG(ERROR) << "\nTraceback (most recent call last):\n" - << ret.GetError().CallStackToString() << "\n" - << ret.GetError().class_name() << ": " << ret.GetError().msg() - << "\npass_name: " << ctx_.drr_ctx->pass_name.value(); + VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " mismatched."; + VLOG(6) << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" + << ret.GetError().class_name() << ": " << ret.GetError().msg() + << "\npass_name: " << ctx_.drr_ctx->pass_name.value(); return false; } bool success = ret.GetOkValue(); @@ -2430,7 +2431,7 @@ class DefaultAnchorApLowerFusionOpPattern : public pir::RewritePattern { return false; } ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); - LOG(ERROR) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; + VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; ADT_LET_CONST_REF( success, ap_rewriter_.Rewrite(opt_match_ctx.value(), op, rewriter)); if (success) { From 639a8b3574f11dc1f22ea8555d7269ed4ca6c327 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 19 Feb 2025 12:25:34 +0800 Subject: [PATCH 05/43] Rename ap_lower_fusion_op_pass to ap_generic_drr_pass. 
--- ...fusion_op_pass.h => ap_generic_drr_pass.h} | 5 +- .../paddle/pass/ir_helper_method_class.h | 2 +- ...sion_op_pass.cc => ap_generic_drr_pass.cc} | 104 +++++++++--------- .../operator/transforms/add_cinn_pass.cc | 15 ++- 4 files changed, 61 insertions(+), 65 deletions(-) rename paddle/ap/include/paddle/pass/{ap_lower_fusion_op_pass.h => ap_generic_drr_pass.h} (90%) rename paddle/ap/src/paddle/pass/{ap_lower_fusion_op_pass.cc => ap_generic_drr_pass.cc} (97%) diff --git a/paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h b/paddle/ap/include/paddle/pass/ap_generic_drr_pass.h similarity index 90% rename from paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h rename to paddle/ap/include/paddle/pass/ap_generic_drr_pass.h index 1254056ca335c7..9a5667c21521ba 100644 --- a/paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h +++ b/paddle/ap/include/paddle/pass/ap_generic_drr_pass.h @@ -34,10 +34,9 @@ namespace cinn { namespace dialect { namespace ir { -std::optional> -CreateApLowerFusionOpAbstractDrrPass( +std::optional> CreateApGenericAbstractDrrPass( const std::weak_ptr& circlable_ref_list); -std::optional> CreateApLowerFusionOpClassicDrrPass( +std::optional> CreateApGenericClassicDrrPass( const std::weak_ptr& circlable_ref_list); std::optional> CreateAccessTopoDrrPass( diff --git a/paddle/ap/include/paddle/pass/ir_helper_method_class.h b/paddle/ap/include/paddle/pass/ir_helper_method_class.h index 3d63ef993944c8..fc234cc74a654b 100644 --- a/paddle/ap/include/paddle/pass/ir_helper_method_class.h +++ b/paddle/ap/include/paddle/pass/ir_helper_method_class.h @@ -19,7 +19,7 @@ #include "paddle/ap/include/axpr/lambda_expr_builder.h" #include "paddle/ap/include/drr/drr_value_helper.h" #include "paddle/ap/include/paddle/pass/ap_drr_helper.h" -#include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" +#include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h" #include "paddle/ap/include/paddle/pass/ir_helper.h" #include 
"paddle/ap/include/paddle/pir/op_dialect.h" #include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h" diff --git a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc similarity index 97% rename from paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc rename to paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc index 1c558379bd39e4..61d71e7a517da2 100644 --- a/paddle/ap/src/paddle/pass/ap_lower_fusion_op_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" +#include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h" #include "paddle/ap/include/memory/circlable_ref_list_base.h" #include "paddle/ap/include/adt/topo_walker.h" @@ -223,7 +223,7 @@ class NaiveDrrCtxProvider : public DrrCtxProvider { } }; -struct ApLowerFusionOpPatternCtx { +struct ApGenericDrrPatternCtx { std::shared_ptr drr_ctx_provider_; DrrCtx drr_ctx; std::vector res_ptn_outputs; @@ -236,7 +236,7 @@ struct ApLowerFusionOpPatternCtx { return drr_ctx_provider_; } - static adt::Result MakeFromDrrCtx( + static adt::Result MakeFromDrrCtx( const DrrCtx& drr_ctx, std::optional steps_limit, const std::shared_ptr& drr_ctx_provider) { @@ -253,13 +253,13 @@ struct ApLowerFusionOpPatternCtx { } ADT_LET_CONST_REF(anchor_op_name, GetAnchorOpName(opt_native_op_anchor, default_anchor)); - return ApLowerFusionOpPatternCtx{drr_ctx_provider, - drr_ctx, - res_ptn_outputs, - default_anchor, - opt_native_op_anchor, - anchor_op_name, - steps_limit}; + return ApGenericDrrPatternCtx{drr_ctx_provider, + drr_ctx, + res_ptn_outputs, + default_anchor, + opt_native_op_anchor, + anchor_op_name, + steps_limit}; } static adt::Result GetAnchorOpName( @@ -289,12 +289,12 @@ struct ApLowerFusionOpPatternCtx { }; struct ApRewriter { - ApLowerFusionOpPatternCtx 
ctx_; + ApGenericDrrPatternCtx ctx_; adt::Result> (*Match_)(const DrrCtx&, pir::Operation* op); mutable ApDrrHelper ap_drr_helper_; - ApRewriter(const ApLowerFusionOpPatternCtx& ctx, + ApRewriter(const ApGenericDrrPatternCtx& ctx, adt::Result> (*Match)( const DrrCtx&, pir::Operation* op)) : ctx_(ctx), @@ -1644,15 +1644,15 @@ struct ConstraintApplier { } }; -struct NativeOpAnchorApLowerFusionOpPatternMatcher { - const ApLowerFusionOpPatternCtx& ctx_; +struct NativeOpAnchorApGenericDrrPatternMatcher { + const ApGenericDrrPatternCtx& ctx_; - using Self = NativeOpAnchorApLowerFusionOpPatternMatcher; + using Self = NativeOpAnchorApGenericDrrPatternMatcher; static adt::Result> Match(const DrrCtx& drr_ctx, pir::Operation* op) { ADT_LET_CONST_REF(pattern_ctx, - ApLowerFusionOpPatternCtx::MakeFromDrrCtx( + ApGenericDrrPatternCtx::MakeFromDrrCtx( drr_ctx, /*steps_limit=*/std::nullopt, std::make_shared(drr_ctx))); @@ -2251,20 +2251,19 @@ struct OpEraseHelepr { } }; -class NativeOpAnchorApLowerFusionOpPattern : public pir::RewritePattern { +class NativeOpAnchorApGenericDrrPattern : public pir::RewritePattern { private: - ApLowerFusionOpPatternCtx ctx_; + ApGenericDrrPatternCtx ctx_; ApRewriter ap_rewriter_; mutable std::size_t times_; public: - NativeOpAnchorApLowerFusionOpPattern(pir::IrContext* ir_context, - const ApLowerFusionOpPatternCtx& ctx) + NativeOpAnchorApGenericDrrPattern(pir::IrContext* ir_context, + const ApGenericDrrPatternCtx& ctx) : pir::RewritePattern(ctx.anchor_op_name, 1, ir_context, {}), ctx_(ctx), times_(0), - ap_rewriter_(ctx, &NativeOpAnchorApLowerFusionOpPatternMatcher::Match) { - } + ap_rewriter_(ctx, &NativeOpAnchorApGenericDrrPatternMatcher::Match) {} bool MatchAndRewrite( pir::Operation* op, @@ -2310,19 +2309,19 @@ class NativeOpAnchorApLowerFusionOpPattern : public pir::RewritePattern { adt::Result> GetMatchCtx( pir::Operation* op) const { - return NativeOpAnchorApLowerFusionOpPatternMatcher{ctx_}.GetMatchCtx(op); + return 
NativeOpAnchorApGenericDrrPatternMatcher{ctx_}.GetMatchCtx(op); } }; -struct DefaultAnchorApLowerFusionOpPatternMatcher { - const ApLowerFusionOpPatternCtx& ctx_; +struct DefaultAnchorApGenericDrrPatternMatcher { + const ApGenericDrrPatternCtx& ctx_; - using Self = DefaultAnchorApLowerFusionOpPatternMatcher; + using Self = DefaultAnchorApGenericDrrPatternMatcher; static adt::Result> Match(const DrrCtx& drr_ctx, pir::Operation* op) { ADT_LET_CONST_REF(pattern_ctx, - ApLowerFusionOpPatternCtx::MakeFromDrrCtx( + ApGenericDrrPatternCtx::MakeFromDrrCtx( drr_ctx, /*times_step=*/std::nullopt, std::make_shared(drr_ctx))); @@ -2386,19 +2385,19 @@ struct DefaultAnchorApLowerFusionOpPatternMatcher { } }; -class DefaultAnchorApLowerFusionOpPattern : public pir::RewritePattern { +class DefaultAnchorApGenericDrrPattern : public pir::RewritePattern { private: - ApLowerFusionOpPatternCtx ctx_; + ApGenericDrrPatternCtx ctx_; mutable std::size_t times_; ApRewriter ap_rewriter_; public: - DefaultAnchorApLowerFusionOpPattern(pir::IrContext* ir_context, - const ApLowerFusionOpPatternCtx& ctx) + DefaultAnchorApGenericDrrPattern(pir::IrContext* ir_context, + const ApGenericDrrPatternCtx& ctx) : pir::RewritePattern(ctx.anchor_op_name, 1, ir_context, {}), ctx_(ctx), times_(0), - ap_rewriter_(ctx, &DefaultAnchorApLowerFusionOpPatternMatcher::Match) {} + ap_rewriter_(ctx, &DefaultAnchorApGenericDrrPatternMatcher::Match) {} bool MatchAndRewrite( pir::Operation* op, @@ -2445,17 +2444,17 @@ class DefaultAnchorApLowerFusionOpPattern : public pir::RewritePattern { adt::Result> GetMatchCtx( pir::Operation* op) const { - return DefaultAnchorApLowerFusionOpPatternMatcher{ctx_}.GetMatchCtx(op); + return DefaultAnchorApGenericDrrPatternMatcher{ctx_}.GetMatchCtx(op); } }; -class ApLowerFusionOpPass : public pir::PatternRewritePass { +class ApGenericDrrPass : public pir::PatternRewritePass { private: std::shared_ptr drr_ctx_provider_; std::optional steps_limit_; public: - explicit ApLowerFusionOpPass( + 
explicit ApGenericDrrPass( const std::shared_ptr& drr_ctx_provider, const std::string& name, std::optional steps_limit) @@ -2480,13 +2479,13 @@ class ApLowerFusionOpPass : public pir::PatternRewritePass { pir::IrContext* context) { auto AddFusionOpPattern = [&](const auto& drr_ctx) -> adt::Result { ADT_LET_CONST_REF(pattern_ctx, - ApLowerFusionOpPatternCtx::MakeFromDrrCtx( + ApGenericDrrPatternCtx::MakeFromDrrCtx( drr_ctx, steps_limit_, drr_ctx_provider_)); if (pattern_ctx.native_op_anchor.has_value()) { - ps->Add(std::make_unique( + ps->Add(std::make_unique( context, pattern_ctx)); } else { - ps->Add(std::make_unique( + ps->Add(std::make_unique( context, pattern_ctx)); } return adt::Ok{}; @@ -3103,8 +3102,7 @@ std::optional GetRegistrySingleton() { } // namespace -std::optional> -CreateApLowerFusionOpAbstractDrrPass( +std::optional> CreateApGenericAbstractDrrPass( const std::weak_ptr& circlable_ref_list) { if (!GetRegistrySingleton().has_value()) { return std::nullopt; @@ -3123,13 +3121,13 @@ CreateApLowerFusionOpAbstractDrrPass( return std::nullopt; } std::unique_ptr<::pir::Pass> pass = - std::make_unique(drr_ctx_provider, - /*name=*/"abstract", - /*steps_limit=*/std::nullopt); + std::make_unique(drr_ctx_provider, + /*name=*/"abstract", + /*steps_limit=*/std::nullopt); return std::move(pass); } -std::optional> CreateApLowerFusionOpClassicDrrPass( +std::optional> CreateApGenericClassicDrrPass( const std::weak_ptr& circlable_ref_list) { if (!GetRegistrySingleton().has_value()) { return std::nullopt; @@ -3148,9 +3146,9 @@ std::optional> CreateApLowerFusionOpClassicDrrPass( return std::nullopt; } std::unique_ptr<::pir::Pass> pass = - std::make_unique(drr_ctx_provider, - /*name=*/"classic", - /*steps_limit=*/std::nullopt); + std::make_unique(drr_ctx_provider, + /*name=*/"classic", + /*steps_limit=*/std::nullopt); return std::move(pass); } @@ -3175,9 +3173,9 @@ std::optional> CreateAccessTopoDrrPass( return std::nullopt; } std::unique_ptr<::pir::Pass> pass = - 
std::make_unique(drr_ctx_provider, - /*name=*/"tag_access_topo", - /*steps_limit=*/steps_limit); + std::make_unique(drr_ctx_provider, + /*name=*/"tag_access_topo", + /*steps_limit=*/steps_limit); return std::move(pass); } @@ -3200,9 +3198,9 @@ std::optional> CreateCustomAccessTopoDrrPass( return std::nullopt; } std::unique_ptr<::pir::Pass> pass = - std::make_unique(drr_ctx_provider, - /*name=*/"custom_access_topo", - /*steps_limit=*/steps_limit); + std::make_unique(drr_ctx_provider, + /*name=*/"custom_access_topo", + /*steps_limit=*/steps_limit); return std::move(pass); } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 6ad748ad06b1b8..0bdee5861006df 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -28,7 +28,7 @@ #include "paddle/pir/include/pass/pass_manager.h" #include "paddle/ap/include/memory/guard.h" -#include "paddle/ap/include/paddle/pass/ap_lower_fusion_op_pass.h" +#include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.h" @@ -228,24 +228,23 @@ void ApplyCinnLowerPass( } if (FLAGS_enable_ap) { ap::memory::Guard guard{}; - if (auto pass = - CreateApLowerFusionOpClassicDrrPass(guard.circlable_ref_list())) { + if (auto pass = CreateApGenericClassicDrrPass(guard.circlable_ref_list())) { pass_manager->AddPass(std::move(pass.value())); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); - pir::IrPrinter(LOG(ERROR) << "before ApLowerFusionOpClassicDrrPass:\n") + pir::IrPrinter(LOG(ERROR) << "before ApGenericClassicDrrPass:\n") .PrintProgram(program); pass_manager->Run(program); - pir::IrPrinter(LOG(ERROR) << "after ApLowerFusionOpClassicDrrPass:\n") + 
pir::IrPrinter(LOG(ERROR) << "after ApGenericClassicDrrPass:\n") .PrintProgram(program); } if (auto pass = - CreateApLowerFusionOpAbstractDrrPass(guard.circlable_ref_list())) { + CreateApGenericAbstractDrrPass(guard.circlable_ref_list())) { pass_manager->AddPass(std::move(pass.value())); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); - pir::IrPrinter(LOG(ERROR) << "before ApLowerFusionOpAbstractDrrPass:\n") + pir::IrPrinter(LOG(ERROR) << "before ApGenericAbstractDrrPass:\n") .PrintProgram(program); pass_manager->Run(program); - pir::IrPrinter(LOG(ERROR) << "after ApLowerFusionOpAbstractDrrPass:\n") + pir::IrPrinter(LOG(ERROR) << "after ApGenericAbstractDrrPass:\n") .PrintProgram(program); pass_manager = CreatePassManager(); pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); From 1e0a138a80f27fff6841f21d77e4a26698650903 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 19 Feb 2025 13:45:43 +0800 Subject: [PATCH 06/43] Rename ap_unary -> ap_variadic, ApUnary -> ApVariadic. 
--- ...ap_unary_kernel.h => ap_variadic_kernel.h} | 2 +- .../ap/src/paddle/pass/ap_generic_drr_pass.cc | 4 +- ..._unary_kernel.cc => ap_variadic_kernel.cc} | 4 +- paddle/phi/infermeta/multiary.cc | 18 ++++---- paddle/phi/infermeta/multiary.h | 16 +++---- .../{ap_unary.cu => ap_variadic_kernel.cu} | 42 +++++++++---------- paddle/phi/ops/yaml/ops.yaml | 6 +-- 7 files changed, 46 insertions(+), 46 deletions(-) rename paddle/ap/include/kernel_dispatch/{ap_unary_kernel.h => ap_variadic_kernel.h} (96%) rename paddle/ap/src/paddle/phi/{ap_unary_kernel.cc => ap_variadic_kernel.cc} (99%) rename paddle/phi/kernels/gpu/{ap_unary.cu => ap_variadic_kernel.cu} (68%) diff --git a/paddle/ap/include/kernel_dispatch/ap_unary_kernel.h b/paddle/ap/include/kernel_dispatch/ap_variadic_kernel.h similarity index 96% rename from paddle/ap/include/kernel_dispatch/ap_unary_kernel.h rename to paddle/ap/include/kernel_dispatch/ap_variadic_kernel.h index cfa05c436c0b2d..96fb96b915b2a0 100644 --- a/paddle/ap/include/kernel_dispatch/ap_unary_kernel.h +++ b/paddle/ap/include/kernel_dispatch/ap_variadic_kernel.h @@ -28,7 +28,7 @@ class DenseTensor; namespace ap::kernel_dispatch { -adt::Result ApUnaryKernel( +adt::Result ApVariadicKernel( const DeviceCtx& device_ctx, const std::vector& xs, int num_outputs, diff --git a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc index 61d71e7a517da2..eeefcba5003d0f 100644 --- a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc @@ -1140,14 +1140,14 @@ struct ApRewriter { const std::string& infer_meta_lambda_str, const std::string& kernel_dispatch_lambda_str, const std::string& kernel_dispatch_const_data_lambda_str) const { - auto ap_unary = rewriter->Build( + auto ap_variadic = rewriter->Build( input, num_outputs, code_gen_lambda_str, infer_meta_lambda_str, kernel_dispatch_lambda_str, kernel_dispatch_const_data_lambda_str); - return ap_unary.out(); + return 
ap_variadic.out(); } adt::Result> GetPackedOpOutputValues( diff --git a/paddle/ap/src/paddle/phi/ap_unary_kernel.cc b/paddle/ap/src/paddle/phi/ap_variadic_kernel.cc similarity index 99% rename from paddle/ap/src/paddle/phi/ap_unary_kernel.cc rename to paddle/ap/src/paddle/phi/ap_variadic_kernel.cc index 63a9bef64db131..14b44434e0b57d 100644 --- a/paddle/ap/src/paddle/phi/ap_unary_kernel.cc +++ b/paddle/ap/src/paddle/phi/ap_variadic_kernel.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ap/include/kernel_dispatch/ap_unary_kernel.h" +#include "paddle/ap/include/kernel_dispatch/ap_variadic_kernel.h" #include #include @@ -271,7 +271,7 @@ adt::Result> MakeMutableTensors( return ret; } -adt::Result ApUnaryKernel( +adt::Result ApVariadicKernel( const DeviceCtx& device_ctx, const std::vector& xs, int num_outputs, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index e90fc02c5c9fbd..f3fc8c39f32a49 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -469,18 +469,18 @@ void AddNInferMeta(const std::vector& x, out->set_dtype(x[0]->dtype()); } -void ApUnaryInferMeta(const std::vector& xs, - int num_outputs, - const std::string& code_module_lambda, - const std::string& infer_meta_lambda, - const std::string& kernel_dispatch_lambda, - const std::string& kernel_dispatch_const_data_lambda, - std::vector outs, - MetaConfig config) { +void ApVariadicInferMeta(const std::vector& xs, + int num_outputs, + const std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs, + MetaConfig config) { ApInferMetaHelper helper{}; const auto& ret = helper.InferMeta(infer_meta_lambda, &xs, &outs); PADDLE_ENFORCE(!ret.HasError(), - "ApUnaryInferMeta failed. 
\nTraceback (most recent call " + "ApVariadicInferMeta failed. \nTraceback (most recent call " "last):\n%s\n%s: %s. ", ret.GetError().CallStackToString(), ret.GetError().class_name(), diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 6a6fb6c4226200..78fc1dc54aec85 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -140,14 +140,14 @@ void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); -void ApUnaryInferMeta(const std::vector& xs, - int num_outputs, - const std::string& code_module_lambda, - const std::string& infer_meta_lambda, - const std::string& kernel_dispatch_lambda, - const std::string& kernel_dispatch_const_data_lambda, - std::vector outs, - MetaConfig config = MetaConfig()); +void ApVariadicInferMeta(const std::vector& xs, + int num_outputs, + const std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs, + MetaConfig config = MetaConfig()); void AddNTensorArrayInferMeta(const std::vector& x, MetaTensor* out, diff --git a/paddle/phi/kernels/gpu/ap_unary.cu b/paddle/phi/kernels/gpu/ap_variadic_kernel.cu similarity index 68% rename from paddle/phi/kernels/gpu/ap_unary.cu rename to paddle/phi/kernels/gpu/ap_variadic_kernel.cu index be8c7fd7985360..cdd872017b641a 100644 --- a/paddle/phi/kernels/gpu/ap_unary.cu +++ b/paddle/phi/kernels/gpu/ap_variadic_kernel.cu @@ -25,20 +25,20 @@ #include "paddle/phi/kernels/impl/activation_grad_impl.h" #include "paddle/phi/kernels/impl/activation_impl.h" -#include "paddle/ap/include/kernel_dispatch/ap_unary_kernel.h" +#include "paddle/ap/include/kernel_dispatch/ap_variadic_kernel.h" #include "paddle/ap/include/paddle/phi/device_ctx.h" namespace phi { template -void ApUnaryKernel(const Context& dev_ctx, - const std::vector& xs, - int num_outputs, - const std::string& 
code_module_lambda, - const std::string& infer_meta_lambda, - const std::string& kernel_dispatch_lambda, - const std::string& kernel_dispatch_const_data_lambda, - std::vector outs) { +void ApVariadicKernel(const Context& dev_ctx, + const std::vector& xs, + int num_outputs, + const std::string& code_module_lambda, + const std::string& infer_meta_lambda, + const std::string& kernel_dispatch_lambda, + const std::string& kernel_dispatch_const_data_lambda, + std::vector outs) { PADDLE_ENFORCE_GT( xs.size(), 0, @@ -58,14 +58,14 @@ void ApUnaryKernel(const Context& dev_ctx, std::make_shared>(&dev_ctx); ap::kernel_dispatch::DeviceCtx ap_device_ctx{impl}; const auto& ret = - ap::kernel_dispatch::ApUnaryKernel(ap_device_ctx, - xs, - num_outputs, - code_module_lambda, - infer_meta_lambda, - kernel_dispatch_lambda, - kernel_dispatch_const_data_lambda, - outs); + ap::kernel_dispatch::ApVariadicKernel(ap_device_ctx, + xs, + num_outputs, + code_module_lambda, + infer_meta_lambda, + kernel_dispatch_lambda, + kernel_dispatch_const_data_lambda, + outs); PADDLE_ENFORCE( !ret.HasError(), "ap_kernel failed. \nTraceback (most recent call last):\n%s\n%s: %s. 
", @@ -77,18 +77,18 @@ void ApUnaryKernel(const Context& dev_ctx, } // namespace phi #ifdef PADDLE_WITH_HIP -PD_REGISTER_KERNEL(ap_unary, +PD_REGISTER_KERNEL(ap_variadic, GPU, ALL_LAYOUT, - phi::ApUnaryKernel, + phi::ApVariadicKernel, float, double, phi::dtype::float16) {} #else -PD_REGISTER_KERNEL(ap_unary, +PD_REGISTER_KERNEL(ap_variadic, GPU, ALL_LAYOUT, - phi::ApUnaryKernel, + phi::ApVariadicKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index fd9deb0839c18a..10ea6660edd30c 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -269,13 +269,13 @@ traits : paddle::dialect::ForwardOnlyTrait interfaces : paddle::dialect::InferSymbolicShapeInterface -- op : ap_unary +- op : ap_variadic args : (Tensor[] xs, int num_outputs, str code_module_lambda, str infer_meta_lambda, str rnel_dispatch_lambda, str kernel_dispatch_const_data_lambda) output : Tensor[](out){num_outputs} infer_meta : - func : ApUnaryInferMeta + func : ApVariadicInferMeta kernel : - func : ap_unary + func : ap_variadic - op : apply_per_channel_scale args: (Tensor x, Tensor scales) From 01e2686b90478184dcfe0ae9aaca06ea52c2c4c2 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 20 Feb 2025 11:15:25 +0800 Subject: [PATCH 07/43] Fallback to cinn when ap fails, and disable fuse_gemm_epilogue when ap is enabled. 
--- .../ap/src/paddle/pass/ap_generic_drr_pass.cc | 3 +- .../operator/transforms/add_cinn_pass.cc | 82 ++++++++++--------- paddle/common/flags.cc | 21 +++++ 3 files changed, 68 insertions(+), 38 deletions(-) diff --git a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc index eeefcba5003d0f..a628b3e6104ce0 100644 --- a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc @@ -2291,10 +2291,11 @@ class NativeOpAnchorApGenericDrrPattern : public pir::RewritePattern { adt::Result TryMatchAndRewrite(pir::Operation* op, pir::PatternRewriter* rewriter) const { ADT_LET_CONST_REF(opt_match_ctx, GetMatchCtx(op)); + ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); if (!opt_match_ctx.has_value()) { + VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " mismatched."; return false; } - ADT_CHECK(ctx_.drr_ctx->pass_name.has_value()); VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " matched."; ADT_LET_CONST_REF( success, ap_rewriter_.Rewrite(opt_match_ctx.value(), op, rewriter)); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 0bdee5861006df..3b7ec4318dce26 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -74,6 +74,7 @@ COMMON_DECLARE_bool(enable_cinn_accuracy_check); COMMON_DECLARE_bool(enable_fuse_parallel_matmul_pass); COMMON_DECLARE_bool(enable_fusion_fallback); COMMON_DECLARE_bool(enable_ap); +COMMON_DECLARE_bool(ap_enable_classic_gemm_epilogue); COMMON_DECLARE_bool(logging_pir_py_code_dump_symbolic_dims); namespace cinn::dialect::ir { @@ -130,7 +131,9 @@ void ApplyPdToCinnPass( std::shared_ptr pass_manager = CreatePassManager(); pass_manager->AddPass(cinn::dialect::ir::CreateReduceAsToSumPass()); 
pass_manager->AddPass(cinn::dialect::ir::CreateReplaceZeroScaleToFullPass()); - pass_manager->AddPass(pir::CreateFusedGemmEpiloguePass()); + if (!FLAGS_enable_ap || FLAGS_ap_enable_classic_gemm_epilogue) { + pass_manager->AddPass(pir::CreateFusedGemmEpiloguePass()); + } if (FLAGS_enable_fuse_parallel_matmul_pass) { pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass()); } @@ -205,6 +208,32 @@ void ApplyDivideGroupOpToFusionOpPass( pass_manager->Run(program); } +void ApplyApGenericDrrPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + ap::memory::Guard guard{}; + if (auto pass = CreateApGenericClassicDrrPass(guard.circlable_ref_list())) { + pass_manager->AddPass(std::move(pass.value())); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pir::IrPrinter(LOG(ERROR) << "before ApGenericClassicDrrPass:\n") + .PrintProgram(program); + pass_manager->Run(program); + pir::IrPrinter(LOG(ERROR) << "after ApGenericClassicDrrPass:\n") + .PrintProgram(program); + } + if (auto pass = CreateApGenericAbstractDrrPass(guard.circlable_ref_list())) { + pass_manager->AddPass(std::move(pass.value())); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pir::IrPrinter(LOG(ERROR) << "before ApGenericAbstractDrrPass:\n") + .PrintProgram(program); + pass_manager->Run(program); + pir::IrPrinter(LOG(ERROR) << "after ApGenericAbstractDrrPass:\n") + .PrintProgram(program); + } +} + void ApplyCinnLowerPass( ::pir::Program* program, const std::function()>& @@ -227,44 +256,23 @@ void ApplyCinnLowerPass( pass_manager->AddPass(cinn::dialect::ir::CreateAccuracyCheckPass()); } if (FLAGS_enable_ap) { - ap::memory::Guard guard{}; - if (auto pass = CreateApGenericClassicDrrPass(guard.circlable_ref_list())) { - pass_manager->AddPass(std::move(pass.value())); - pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); - pir::IrPrinter(LOG(ERROR) << "before 
ApGenericClassicDrrPass:\n") - .PrintProgram(program); - pass_manager->Run(program); - pir::IrPrinter(LOG(ERROR) << "after ApGenericClassicDrrPass:\n") - .PrintProgram(program); - } - if (auto pass = - CreateApGenericAbstractDrrPass(guard.circlable_ref_list())) { - pass_manager->AddPass(std::move(pass.value())); - pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); - pir::IrPrinter(LOG(ERROR) << "before ApGenericAbstractDrrPass:\n") - .PrintProgram(program); - pass_manager->Run(program); - pir::IrPrinter(LOG(ERROR) << "after ApGenericAbstractDrrPass:\n") - .PrintProgram(program); - pass_manager = CreatePassManager(); - pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); - pass_manager->Run(program); - } - } else { - if (FLAGS_enable_fusion_fallback) { - VLOG(0) << "Enable Fusion Fallback Pass"; - pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); - } - if (has_dynamic_shape && !force_static_shape) { - pass_manager->AddPass( - cinn::dialect::ir::CreateLowerCinnDyShapeFusionOpPass()); - } else { - pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); - } + VLOG(0) << "Enable AP Generic DRR Pass"; + ApplyApGenericDrrPass(program, CreatePassManager); + } + if (FLAGS_enable_fusion_fallback) { + VLOG(0) << "Enable Fusion Fallback Pass"; + pass_manager->AddPass(cinn::dialect::ir::CreateFusionFallbackPass()); + } + if (has_dynamic_shape && !force_static_shape) { pass_manager->AddPass( - cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass()); - pass_manager->Run(program); + cinn::dialect::ir::CreateLowerCinnDyShapeFusionOpPass()); + } else { + pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); } + pass_manager->AddPass( + cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass()); + + pass_manager->Run(program); } template diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index da9f7e9680f612..f8f41d87ae47a1 100644 --- a/paddle/common/flags.cc +++ 
b/paddle/common/flags.cc @@ -1577,8 +1577,29 @@ PHI_DEFINE_EXPORTED_bool(logging_pir_py_code_dump_symbolic_dims, false, "whether dump symbolic dims into pir py code."); +/** + * Enable Abstract Pass + * Name: enable_ap + * Since Version: 3.0.0 + * Value Range: bool, default=false + * Example: + * Note: If True, abstract pass will be enabled to optimize performance. + */ PHI_DEFINE_EXPORTED_bool(enable_ap, false, "whether enable abstract pass."); +/** + * Enable Classic fused_gemm_epilogue when Abstract Pass is enabled. + * Name: ap_enable_classic_gemm_epilogue + * Since Version: 3.0.0 + * Value Range: bool, default=false + * Example: + * Note: If True, classic fused_gemm_epilogue will be enabled. + */ +PHI_DEFINE_EXPORTED_bool(ap_enable_classic_gemm_epilogue, + false, + "whether enable classic fused_gemm_epilogue when " + "abstract pass is enabled."); + PHI_DEFINE_EXPORTED_bool( pir_interpreter_record_stream_for_gc_cache, false, From 34d490a6f453eab409614133db8192f480b18637 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 20 Feb 2025 15:51:03 +0800 Subject: [PATCH 08/43] Fix compiling error in CI, including: using std::memcpy instead of reinterpret_cast to avoid strict-aliasing. 
--- paddle/ap/include/axpr/core_expr.h | 5 ++++- paddle/ap/include/axpr/data_type_method_class.h | 2 +- paddle/ap/include/axpr/data_value_util.h | 7 +++---- paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc | 1 - 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/ap/include/axpr/core_expr.h b/paddle/ap/include/axpr/core_expr.h index e7e0d88a10c027..ed370f021f159c 100644 --- a/paddle/ap/include/axpr/core_expr.h +++ b/paddle/ap/include/axpr/core_expr.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -115,7 +116,9 @@ inline size_t GetHashValue(const Atomic& atomic) { [](const bool val) -> size_t { return val; }, [](const int64_t val) -> size_t { return val; }, [](const double val) -> size_t { - return *reinterpret_cast(&val); + size_t res; + std::memcpy(&res, &val, sizeof(double)); + return res; }, [](const std::string& val) -> size_t { return std::hash()(val); diff --git a/paddle/ap/include/axpr/data_type_method_class.h b/paddle/ap/include/axpr/data_type_method_class.h index 8afe0aa12685be..e71c9019f3801c 100644 --- a/paddle/ap/include/axpr/data_type_method_class.h +++ b/paddle/ap/include/axpr/data_type_method_class.h @@ -65,7 +65,7 @@ struct DataTypeMethodClass { builtin_symbol::NE>) { return &This::NE; } else { - std::nullopt; + return adt::Nothing{}; } } diff --git a/paddle/ap/include/axpr/data_value_util.h b/paddle/ap/include/axpr/data_value_util.h index b2de3fb1e5acfe..e64d0f217ecb46 100644 --- a/paddle/ap/include/axpr/data_value_util.h +++ b/paddle/ap/include/axpr/data_value_util.h @@ -67,12 +67,11 @@ struct ArithmeticBinaryOpHelper { return adt::errors::ZeroDivisionError{"division or modulo by zero"}; } return ArithmeticMod::Call(lhs, rhs); - } else if constexpr (!std::is_integral_v) { - return adt::errors::TypeError{ - "'%' only support intergral type. 'lhs' is not a intergral type"}; } else { return adt::errors::TypeError{ - "'%' only support intergral type. 
'rhs' is not a intergral type"}; + std::string() + "'%' only support intergral type, but receive: '" + + CppDataType{}.Name() + "' and '" + CppDataType{}.Name() + + "'."}; } } }; diff --git a/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc b/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc index d2e29ddfd0707a..a73ee4a241ee78 100644 --- a/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_op_pattern_ctx_method_class.cc @@ -95,7 +95,6 @@ struct ResPtnOpPatternCtxMethodClass { ADT_CHECK(args.size() == 1); const auto& arg = args.at(0); ADT_LET_CONST_REF(self, self_val.template CastTo()); - using RetT = adt::Result; ADT_LET_CONST_REF(attr_name, arg.template CastTo()); ADT_CHECK(!This{}.IsBasicAttrName(attr_name)) << adt::errors::RuntimeError{ std::string() + "Dead code encounterred. attr_name: " + attr_name}; From 4832e4cd8917e031bac958dbd378c4a28f303a06 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 21 Feb 2025 11:54:09 +0800 Subject: [PATCH 09/43] Fix narrowing conversion error and unused value error. 
--- paddle/ap/include/drr/drr_graph_descriptor.h | 8 ++++---- paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h | 12 ++++++------ paddle/ap/include/graph/node_list.h | 2 +- paddle/ap/include/ir_match/topo_match_ctx.h | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/ap/include/drr/drr_graph_descriptor.h b/paddle/ap/include/drr/drr_graph_descriptor.h index 1abd874bdc827f..f7785864256755 100644 --- a/paddle/ap/include/drr/drr_graph_descriptor.h +++ b/paddle/ap/include/drr/drr_graph_descriptor.h @@ -155,13 +155,13 @@ struct DefaultDrrGraphDescriptor { [](const DrrPackedIrValue&) -> adt::Result { return true; }, [&](const DrrPackedIrOpOperand& impl) -> adt::Result { ADT_LET_CONST_REF(upstreams, impl->node.UpstreamNodes()); - ADT_CHECK(upstreams.size(), 1); + ADT_CHECK(upstreams.size() == 1); ADT_LET_CONST_REF(upstream_node, upstreams.Sole()); return IgnoredNode(upstream_node); }, [&](const DrrPackedIrOpResult& impl) -> adt::Result { ADT_LET_CONST_REF(downstreams, impl->node.DownstreamNodes()); - ADT_CHECK(downstreams.size(), 1); + ADT_CHECK(downstreams.size() == 1); ADT_LET_CONST_REF(downstream_node, downstreams.Sole()); return IgnoredNode(downstream_node); }, @@ -172,14 +172,14 @@ struct DefaultDrrGraphDescriptor { [](const DrrNativeIrOpOperand&) -> adt::Result { return false; }, [&](const DrrOptPackedIrOpOperand& impl) -> adt::Result { ADT_LET_CONST_REF(upstreams, impl->node.UpstreamNodes()); - ADT_CHECK(upstreams.size(), 1); + ADT_CHECK(upstreams.size() == 1); ADT_LET_CONST_REF(upstream_node, upstreams.Sole()); return IgnoredNode(upstream_node); }, [](const DrrNativeIrOpResult&) -> adt::Result { return false; }, [&](const DrrOptPackedIrOpResult& impl) -> adt::Result { ADT_LET_CONST_REF(downstreams, impl->node.DownstreamNodes()); - ADT_CHECK(downstreams.size(), 1); + ADT_CHECK(downstreams.size() == 1); ADT_LET_CONST_REF(downstream_node, downstreams.Sole()); return IgnoredNode(downstream_node); }); diff --git 
a/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h b/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h index 21e5f2abde7aa6..dd0d14725f4317 100644 --- a/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h +++ b/paddle/ap/include/drr/op_tensor_pattern_ctx_helper.h @@ -68,7 +68,7 @@ struct OpTensorPatternCtxHelper { op_pattern_ctx, adt::WeakPtrLock(native_ir_op->op_declare->op_pattern_ctx)); const auto& node_arena = op_pattern_ctx->node_arena; - for (int i = 0; i < inputs->size(); ++i) { + for (size_t i = 0; i < inputs->size(); ++i) { const auto& native_ir_op_operand = node_arena->New([&](const auto& node) { return NativeIrOpOperand{node, i}; }); @@ -81,7 +81,7 @@ struct OpTensorPatternCtxHelper { graph::IndexedTag{}, graph::IndexedTag{})); } - for (int i = 0; i < outputs->size(); ++i) { + for (size_t i = 0; i < outputs->size(); ++i) { ADT_LET_CONST_REF(output_upstream_nodes, outputs->at(i)->node.UpstreamNodes()); ADT_CHECK(output_upstream_nodes.size() == 0); @@ -114,7 +114,7 @@ struct OpTensorPatternCtxHelper { op_pattern_ctx, adt::WeakPtrLock(packed_ir_op->op_declare->op_pattern_ctx)); const auto& node_arena = op_pattern_ctx->node_arena; - for (int i = 0; i < inputs->size(); ++i) { + for (size_t i = 0; i < inputs->size(); ++i) { const auto& packed_ir_op_operand = node_arena->New([&](const auto& node) { return PackedIrOpOperand{node, i}; }); @@ -127,7 +127,7 @@ struct OpTensorPatternCtxHelper { graph::IndexedTag{}, graph::UnindexedTag{})); } - for (int i = 0; i < outputs->size(); ++i) { + for (size_t i = 0; i < outputs->size(); ++i) { ADT_LET_CONST_REF(output_upstream_nodes, outputs->at(i).node().UpstreamNodes()); ADT_CHECK(output_upstream_nodes.size() == 0); @@ -160,7 +160,7 @@ struct OpTensorPatternCtxHelper { op_pattern_ctx, adt::WeakPtrLock(packed_ir_op->op_declare->op_pattern_ctx)); const auto& node_arena = op_pattern_ctx->node_arena; - for (int i = 0; i < inputs->size(); ++i) { + for (size_t i = 0; i < inputs->size(); ++i) { const auto& 
packed_ir_op_operand = node_arena->New([&](const auto& node) { return OptPackedIrOpOperand{node, i}; }); @@ -173,7 +173,7 @@ struct OpTensorPatternCtxHelper { graph::IndexedTag{}, graph::UnindexedTag{})); } - for (int i = 0; i < outputs->size(); ++i) { + for (size_t i = 0; i < outputs->size(); ++i) { ADT_LET_CONST_REF(output_upstream_nodes, outputs->at(i).node().UpstreamNodes()); ADT_CHECK(output_upstream_nodes.size() == 0); diff --git a/paddle/ap/include/graph/node_list.h b/paddle/ap/include/graph/node_list.h index 56a0553daa4370..a2a1c693a80b3f 100644 --- a/paddle/ap/include/graph/node_list.h +++ b/paddle/ap/include/graph/node_list.h @@ -84,7 +84,7 @@ struct NodeList : public ListTag>> { return adt::errors::TypeError{"UndefinedList has no sole data"}; }, [](const auto& l) -> adt::Result> { - ADT_CHECK(l.data->size(), 1); + ADT_CHECK(l.data->size() == 1); return l.data->at(0); }); } diff --git a/paddle/ap/include/ir_match/topo_match_ctx.h b/paddle/ap/include/ir_match/topo_match_ctx.h index dbdbc7aa144082..748317231ff464 100644 --- a/paddle/ap/include/ir_match/topo_match_ctx.h +++ b/paddle/ap/include/ir_match/topo_match_ctx.h @@ -43,7 +43,7 @@ struct TopoMatchCtxImpl { adt::Result GetSoleBigGraphNode(const sg_node_t& node) const { ADT_LET_CONST_REF(bg_nodes, GetBigGraphNodes(node)); - ADT_CHECK(bg_nodes->size(), 1); + ADT_CHECK(bg_nodes->size() == 1); return *bg_nodes->begin(); } From 4db56bc37748e0e51f1e0f673a164d05f333aac7 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 21 Feb 2025 14:01:06 +0800 Subject: [PATCH 10/43] Fix missing-field-initializers and unused-result error on CI. 
--- paddle/ap/include/adt/adt.h | 12 ++++++------ paddle/ap/include/axpr/builtin_func_name_mgr.h | 2 +- paddle/ap/include/axpr/function.h | 2 +- paddle/ap/include/axpr/module_mgr.h | 2 +- paddle/ap/include/code_gen/code_gen_ctx.h | 2 +- .../code_gen/out_tensor_data_ptr_kernel_arg_id.h | 2 +- paddle/ap/include/drr/drr_ctx.h | 10 +++++----- paddle/ap/include/drr/native_ir_op_declare.h | 2 +- paddle/ap/include/drr/opt_packed_ir_op_declare.h | 2 +- paddle/ap/include/drr/packed_ir_op_declare.h | 2 +- paddle/ap/include/drr/result_pattern_ctx.h | 2 +- .../include/drr/src_ptn_packed_ir_op_declare_data.h | 4 ++-- paddle/ap/include/drr/tensor_pattern_ctx.h | 4 ++-- .../drr/res_ptn_unbound_native_ir_op_method_class.cc | 3 ++- .../drr/res_ptn_unbound_packed_ir_op_method_class.cc | 3 ++- .../drr/src_ptn_unbound_native_ir_op_method_class.cc | 3 ++- 16 files changed, 30 insertions(+), 27 deletions(-) diff --git a/paddle/ap/include/adt/adt.h b/paddle/ap/include/adt/adt.h index 70c0d029f9516d..f6a595d95cc54c 100644 --- a/paddle/ap/include/adt/adt.h +++ b/paddle/ap/include/adt/adt.h @@ -539,12 +539,12 @@ adt::Result> WeakPtrLock(const std::weak_ptr& weak_ptr) { }()) // clang-format off -#define ADT_CHECK(...) /* NOLINT */ \ - if (!(__VA_ARGS__)) /* NOLINT */ \ - return ::ap::adt::errors::Error{::ap::adt::errors::ValueError{ /* NOLINT */ \ - "Check '" #__VA_ARGS__ "' failed." /* NOLINT */ \ - }} << ADT_CURRENT_CODE_LOCATION( /* NOLINT */ \ - __FILE__, __LINE__, __FUNCTION__, #__VA_ARGS__ /* NOLINT */ \ +#define ADT_CHECK(...) /* NOLINT */ \ + if (!(__VA_ARGS__)) /* NOLINT */ \ + return ::ap::adt::errors::Error{::ap::adt::errors::ValueError{ /* NOLINT */ \ + "Check '" #__VA_ARGS__ "' failed." 
/* NOLINT */ \ + }} << ADT_CURRENT_CODE_LOCATION( /* NOLINT */ \ + __FILE__, __LINE__, __FUNCTION__, #__VA_ARGS__ /* NOLINT */ \ ) // clang-format on diff --git a/paddle/ap/include/axpr/builtin_func_name_mgr.h b/paddle/ap/include/axpr/builtin_func_name_mgr.h index 909ec2aeabbda1..53feedf5a8bd6a 100644 --- a/paddle/ap/include/axpr/builtin_func_name_mgr.h +++ b/paddle/ap/include/axpr/builtin_func_name_mgr.h @@ -21,7 +21,7 @@ namespace ap::axpr { struct BuiltinFuncName { - std::optional module_name; + std::optional module_name{}; std::string func_name; std::string ToString() const { diff --git a/paddle/ap/include/axpr/function.h b/paddle/ap/include/axpr/function.h index 22405e2f0a9df4..12c376dc3003b5 100644 --- a/paddle/ap/include/axpr/function.h +++ b/paddle/ap/include/axpr/function.h @@ -24,7 +24,7 @@ namespace ap::axpr { template struct FunctionImpl { Lambda lambda; - std::optional> global_frame; + std::optional> global_frame{}; bool operator==(const FunctionImpl& other) const { return this == &other; } diff --git a/paddle/ap/include/axpr/module_mgr.h b/paddle/ap/include/axpr/module_mgr.h index 4343103890dd42..59d084a0386cc5 100644 --- a/paddle/ap/include/axpr/module_mgr.h +++ b/paddle/ap/include/axpr/module_mgr.h @@ -163,7 +163,7 @@ class ModuleMgr { struct ApBuiltinModuleBuilder { std::string module_name; - axpr::AttrMap attr_map; + axpr::AttrMap attr_map{}; void Def(const std::string& name, const axpr::BuiltinFuncType& func) { diff --git a/paddle/ap/include/code_gen/code_gen_ctx.h b/paddle/ap/include/code_gen/code_gen_ctx.h index bb87bbe6d14368..6a2086b9fe7075 100644 --- a/paddle/ap/include/code_gen/code_gen_ctx.h +++ b/paddle/ap/include/code_gen/code_gen_ctx.h @@ -29,7 +29,7 @@ namespace ap::code_gen { template struct CodeGenCtxImpl { - std::optional> ir_match_ctx; + std::optional> ir_match_ctx{}; using DrrNode = drr::Node; using DrrPackedIrOp = drr::PackedIrOp; diff --git a/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h 
b/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h index fca622f0ec4983..65e82e4df3eaa0 100644 --- a/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h +++ b/paddle/ap/include/code_gen/out_tensor_data_ptr_kernel_arg_id.h @@ -25,7 +25,7 @@ namespace ap::code_gen { template struct OutTensorDataPtrKernelArgIdImpl { BirNode ir_value; - std::optional> runtime_getter; + std::optional> runtime_getter{}; bool operator==(const OutTensorDataPtrKernelArgIdImpl& other) const { return this->ir_value == other.ir_value; diff --git a/paddle/ap/include/drr/drr_ctx.h b/paddle/ap/include/drr/drr_ctx.h index f7ae5a39539051..dac7c6c4f68e40 100644 --- a/paddle/ap/include/drr/drr_ctx.h +++ b/paddle/ap/include/drr/drr_ctx.h @@ -28,11 +28,11 @@ namespace ap::drr { struct DrrCtxImpl { std::weak_ptr circlable_ref_list; - std::optional pass_name; - std::optional source_pattern_ctx; - std::optional result_pattern_ctx; - std::optional constraint_func; - std::optional drr_pass_type; + std::optional pass_name{}; + std::optional source_pattern_ctx{}; + std::optional result_pattern_ctx{}; + std::optional constraint_func{}; + std::optional drr_pass_type{}; adt::Result GetSourcePatternCtx() const { ADT_CHECK(this->source_pattern_ctx.has_value()); diff --git a/paddle/ap/include/drr/native_ir_op_declare.h b/paddle/ap/include/drr/native_ir_op_declare.h index c287adafc996f0..ba62f82d1c07d7 100644 --- a/paddle/ap/include/drr/native_ir_op_declare.h +++ b/paddle/ap/include/drr/native_ir_op_declare.h @@ -28,7 +28,7 @@ template struct NativeIrOpDeclareImpl { std::string op_name; std::weak_ptr op_pattern_ctx; - axpr::AttrMap attr_map; + axpr::AttrMap attr_map{}; bool operator==(const NativeIrOpDeclareImpl& other) const { return this->op_name == other.op_name && diff --git a/paddle/ap/include/drr/opt_packed_ir_op_declare.h b/paddle/ap/include/drr/opt_packed_ir_op_declare.h index 6a2703e3b6bc79..928a0f2fdc75d4 100644 --- a/paddle/ap/include/drr/opt_packed_ir_op_declare.h +++ 
b/paddle/ap/include/drr/opt_packed_ir_op_declare.h @@ -30,7 +30,7 @@ template struct OptPackedIrOpDeclareImpl { std::string op_name; std::weak_ptr op_pattern_ctx; - std::optional> data; + std::optional> data{}; bool operator==(const OptPackedIrOpDeclareImpl& other) const { return this->op_name == other.op_name && diff --git a/paddle/ap/include/drr/packed_ir_op_declare.h b/paddle/ap/include/drr/packed_ir_op_declare.h index 05a9af558370a1..287a85d6b70a72 100644 --- a/paddle/ap/include/drr/packed_ir_op_declare.h +++ b/paddle/ap/include/drr/packed_ir_op_declare.h @@ -30,7 +30,7 @@ template struct PackedIrOpDeclareImpl { std::string op_name; std::weak_ptr op_pattern_ctx; - std::optional> data; + std::optional> data{}; bool operator==(const PackedIrOpDeclareImpl& other) const { return this->op_name == other.op_name && diff --git a/paddle/ap/include/drr/result_pattern_ctx.h b/paddle/ap/include/drr/result_pattern_ctx.h index 73f0830a5c271a..84b338abdb2da2 100644 --- a/paddle/ap/include/drr/result_pattern_ctx.h +++ b/paddle/ap/include/drr/result_pattern_ctx.h @@ -30,7 +30,7 @@ struct ResultPatternCtxImpl { OpPatternCtx op_pattern_ctx; TensorPatternCtx tensor_pattern_ctx; SourcePatternCtx source_pattern_ctx; - std::unordered_set internal_native_ir_value_names; + std::unordered_set internal_native_ir_value_names{}; bool operator==(const ResultPatternCtxImpl& other) const { return this != &other; diff --git a/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h b/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h index f5f0fdda7c8ecc..7373dd4ea75357 100644 --- a/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h +++ b/paddle/ap/include/drr/src_ptn_packed_ir_op_declare_data.h @@ -26,9 +26,9 @@ struct SrcPtnPackedIrOpDeclareData : public PackedIrOpDeclareData { SrcPtnPackedIrOpDeclareData() : PackedIrOpDeclareData() {} std::optional> - inner_source_pattern_func; + inner_source_pattern_func{}; - std::optional inner_source_pattern_ctx; + std::optional 
inner_source_pattern_ctx{}; }; } // namespace ap::drr diff --git a/paddle/ap/include/drr/tensor_pattern_ctx.h b/paddle/ap/include/drr/tensor_pattern_ctx.h index f6bd50ea944e0e..3f0c9698998298 100644 --- a/paddle/ap/include/drr/tensor_pattern_ctx.h +++ b/paddle/ap/include/drr/tensor_pattern_ctx.h @@ -28,9 +28,9 @@ struct DrrCtxImpl; struct TensorPatternCtxImpl { std::shared_ptr> node_arena; - mutable std::map uid2ir_value; + mutable std::map uid2ir_value{}; std::weak_ptr drr_ctx; - mutable std::map uid2type; + mutable std::map uid2type{}; bool operator==(const TensorPatternCtxImpl& other) const { return this == &other; diff --git a/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc b/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc index 93f009c4f3045d..426da775a7ce64 100644 --- a/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_unbound_native_ir_op_method_class.cc @@ -101,7 +101,8 @@ struct ResPtnUnboundNativeIrOpMethodClass { ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); ADT_LET_CONST_REF(native_op, Helper{}.GetNativeIrOpByUnboundNativeIrOp(self.value())); - Helper{}.ConnectIrOpAndIrValue(native_op, inputs, outputs); + ADT_RETURN_IF_ERR( + Helper{}.ConnectIrOpAndIrValue(native_op, inputs, outputs)); return adt::Nothing{}; } diff --git a/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc b/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc index 0e808ff3bd2568..7b213d0c16efb0 100644 --- a/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc +++ b/paddle/ap/src/drr/res_ptn_unbound_packed_ir_op_method_class.cc @@ -77,7 +77,8 @@ struct ResPtnUnboundPackedIrOp { ADT_RETURN_IF_ERR(CheckNoRedundantTensorNames(inputs, outputs)); ADT_LET_CONST_REF(packed_op, Helper{}.GetPackedIrOpByUnboundPackedIrOp(self.value())); - Helper{}.ConnectIrOpAndIrValue(packed_op, inputs, outputs); + ADT_RETURN_IF_ERR( + Helper{}.ConnectIrOpAndIrValue(packed_op, inputs, 
outputs)); return adt::Nothing{}; } diff --git a/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc b/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc index e05d83b8a24e71..b2a8a821d62bed 100644 --- a/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc +++ b/paddle/ap/src/drr/src_ptn_unbound_native_ir_op_method_class.cc @@ -119,7 +119,8 @@ struct SrcPtnUnboundNativeIrOp { ADT_LET_CONST_REF(native_outputs, ConvertOutputs(outputs)); ADT_LET_CONST_REF(native_op, Helper{}.GetNativeIrOpByUnboundNativeIrOp(self.value())); - Helper{}.ConnectIrOpAndIrValue(native_op, native_inputs, native_outputs); + ADT_RETURN_IF_ERR(Helper{}.ConnectIrOpAndIrValue( + native_op, native_inputs, native_outputs)); return adt::Nothing{}; } From a441f80338bb37225a9eb4f4ed8b8bee4213abbc Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 21 Feb 2025 19:55:01 +0800 Subject: [PATCH 11/43] Fix some sign-compare error on CI. --- paddle/ap/include/axpr/anf_expr_util.h | 2 +- paddle/ap/include/axpr/binary_func.h | 3 ++- paddle/ap/include/axpr/builtin_symbol.h | 1 + paddle/ap/include/axpr/cps_interpreter.h | 2 +- paddle/ap/include/axpr/list_method_class.h | 2 +- paddle/ap/include/axpr/serializable_list_method_class.h | 2 +- .../ap/include/code_module/api_wrapper_project_maker.h | 2 +- ...t_std_vector_const_meta_tensor_ptr_ptr_method_class.h | 2 +- .../paddle/std_vector_meta_tensor_ptr_ptr_method_class.h | 2 +- paddle/ap/include/rt_module/dl_function.h | 1 + paddle/ap/include/rt_module/function_helper.h | 2 +- paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc | 9 +++++---- paddle/ap/src/paddle/phi/ap_variadic_kernel.cc | 6 +++--- 13 files changed, 20 insertions(+), 16 deletions(-) diff --git a/paddle/ap/include/axpr/anf_expr_util.h b/paddle/ap/include/axpr/anf_expr_util.h index c9f3bec3b30f85..be8aec85a2af8d 100644 --- a/paddle/ap/include/axpr/anf_expr_util.h +++ b/paddle/ap/include/axpr/anf_expr_util.h @@ -115,7 +115,7 @@ struct CoreExprToAnfExprConverter { const 
auto& outer_func = composed_call->outer_func; return outer_func.Match( [&](const Lambda& lambda) -> AnfExpr { - CHECK_EQ(lambda->args.size(), 1); + CHECK_EQ(lambda->args.size(), 1U); const auto& val = ConvertComposedCallToCombined(composed_call); Bind binding{lambda->args.at(0), val}; bindings->emplace_back(std::move(binding)); diff --git a/paddle/ap/include/axpr/binary_func.h b/paddle/ap/include/axpr/binary_func.h index 66abcb90463e46..3a2b7f3308ebb7 100644 --- a/paddle/ap/include/axpr/binary_func.h +++ b/paddle/ap/include/axpr/binary_func.h @@ -37,7 +37,8 @@ namespace ap::axpr { \ template \ static auto Call(const LhsT& lhs, const RhsT& rhs) { \ - return lhs op rhs; \ + using T = std::common_type_t; \ + return static_cast(lhs) op static_cast(rhs); \ } \ }; PEXPR_FOR_EACH_BINARY_OP(DEFINE_ARITHMETIC_BINARY_OP); diff --git a/paddle/ap/include/axpr/builtin_symbol.h b/paddle/ap/include/axpr/builtin_symbol.h index 3b0c3353ed4626..551d43ad462376 100644 --- a/paddle/ap/include/axpr/builtin_symbol.h +++ b/paddle/ap/include/axpr/builtin_symbol.h @@ -14,6 +14,7 @@ #pragma once +#include #include "paddle/ap/include/axpr/adt.h" #include "paddle/ap/include/axpr/binary_func.h" #include "paddle/ap/include/axpr/type.h" diff --git a/paddle/ap/include/axpr/cps_interpreter.h b/paddle/ap/include/axpr/cps_interpreter.h index 52af21bc696dc1..6bc4fbbd54fd34 100644 --- a/paddle/ap/include/axpr/cps_interpreter.h +++ b/paddle/ap/include/axpr/cps_interpreter.h @@ -414,7 +414,7 @@ class CpsInterpreter : public InterpreterBase { ADT_LET_CONST_REF(packed_args, packed.template TryGet>()); const auto& [pos_args, kwargs] = *packed_args; - int lambda_arg_idx = (self.has_value() ? 1 : 0); + size_t lambda_arg_idx = (self.has_value() ? 
1 : 0); ADT_CHECK(lambda_arg_idx + pos_args->size() <= lambda->args.size()) << TypeError{std::string("() takes ") + std::to_string(lambda->args.size()) + diff --git a/paddle/ap/include/axpr/list_method_class.h b/paddle/ap/include/axpr/list_method_class.h index 6d2a72ed2af984..e83c13c2fcfc97 100644 --- a/paddle/ap/include/axpr/list_method_class.h +++ b/paddle/ap/include/axpr/list_method_class.h @@ -120,7 +120,7 @@ struct MethodClassImpl> { if (index < 0) { index += self->size(); } - if (index >= 0 && index < self->size()) { + if (index >= 0 && index < static_cast(self->size())) { return self->at(index); } return adt::errors::IndexError{"list index out of range"}; diff --git a/paddle/ap/include/axpr/serializable_list_method_class.h b/paddle/ap/include/axpr/serializable_list_method_class.h index a11996586264ea..7192c50f5cc61f 100644 --- a/paddle/ap/include/axpr/serializable_list_method_class.h +++ b/paddle/ap/include/axpr/serializable_list_method_class.h @@ -46,7 +46,7 @@ struct MethodClassImpl> { if (index < 0) { index += self->size(); } - if (index >= 0 && index < self->size()) { + if (index >= 0 && index < static_cast(self->size())) { return self->at(index).template CastTo(); } return adt::errors::IndexError{"list index out of range"}; diff --git a/paddle/ap/include/code_module/api_wrapper_project_maker.h b/paddle/ap/include/code_module/api_wrapper_project_maker.h index 0e3b0ce24ea165..51d9cbc60188a7 100644 --- a/paddle/ap/include/code_module/api_wrapper_project_maker.h +++ b/paddle/ap/include/code_module/api_wrapper_project_maker.h @@ -82,7 +82,7 @@ struct ApiWrapperProjectMaker { const std::string& args_var_name) { std::ostringstream ss; ss << func_var_name << "("; - for (int i = 0; i < func_declare->arg_types->size(); ++i) { + for (size_t i = 0; i < func_declare->arg_types->size(); ++i) { if (i > 0) { ss << ", "; } diff --git a/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h 
b/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h index cf964b87c80f10..5de27f52d8427d 100644 --- a/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h +++ b/paddle/ap/include/paddle/const_std_vector_const_meta_tensor_ptr_ptr_method_class.h @@ -53,7 +53,7 @@ struct ConstStdVectorConstMetaTensorPtrPtrMethodClass { if (index < 0) { index += self->size(); } - if (index >= 0 && index < self->size()) { + if (index >= 0 && index < static_cast(self->size())) { return CastItem(self->at(index)); } return adt::errors::IndexError{"vector index out of range"}; diff --git a/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h b/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h index df384ef82dd055..8c8f3a270d5d75 100644 --- a/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h +++ b/paddle/ap/include/paddle/std_vector_meta_tensor_ptr_ptr_method_class.h @@ -53,7 +53,7 @@ struct StdVectorMetaTensorPtrPtrMethodClass { if (index < 0) { index += self->size(); } - if (index >= 0 && index < self->size()) { + if (index >= 0 && index < static_cast(self->size())) { return CastItem(self->at(index)); } return adt::errors::IndexError{"vector index out of range"}; diff --git a/paddle/ap/include/rt_module/dl_function.h b/paddle/ap/include/rt_module/dl_function.h index 3c12d393c5c6cd..7e074b6f8227c7 100644 --- a/paddle/ap/include/rt_module/dl_function.h +++ b/paddle/ap/include/rt_module/dl_function.h @@ -41,6 +41,7 @@ class DlFunction { adt::Result Apply(void* ret, void** args) const { ADT_LET_CONST_REF(dl_handle_guard, adt::WeakPtrLock(dl_handle_)); api_wrapper_(ret, func_, args); + (void)dl_handle_guard; return adt::Ok{}; } diff --git a/paddle/ap/include/rt_module/function_helper.h b/paddle/ap/include/rt_module/function_helper.h index 4255d70909830f..67ae111a48603c 100644 --- a/paddle/ap/include/rt_module/function_helper.h +++ 
b/paddle/ap/include/rt_module/function_helper.h @@ -35,7 +35,7 @@ struct FunctionHelper { std::to_string(func_declare->arg_types->size()) + " arguments, but " + std::to_string(args.size()) + " were given"}; - for (int i = 0; i < args.size(); ++i) { + for (size_t i = 0; i < args.size(); ++i) { const auto& arg_axpr_value = args.at(i); { // check arg type diff --git a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc index a628b3e6104ce0..ff85663d55d569 100644 --- a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc @@ -2275,10 +2275,11 @@ class NativeOpAnchorApGenericDrrPattern : public pir::RewritePattern { } const auto& ret = this->TryMatchAndRewrite(op, &rewriter); if (ret.HasError()) { - LOG(ERROR) << "\nTraceback (most recent call last):\n" - << ret.GetError().CallStackToString() << "\n" - << ret.GetError().class_name() << ": " << ret.GetError().msg() - << "\npass_name: " << ctx_.drr_ctx->pass_name.value(); + VLOG(0) << "drr: " << ctx_.drr_ctx->pass_name.value() << " mismatched."; + VLOG(6) << "\nTraceback (most recent call last):\n" + << ret.GetError().CallStackToString() << "\n" + << ret.GetError().class_name() << ": " << ret.GetError().msg() + << "\npass_name: " << ctx_.drr_ctx->pass_name.value(); return false; } bool success = ret.GetOkValue(); diff --git a/paddle/ap/src/paddle/phi/ap_variadic_kernel.cc b/paddle/ap/src/paddle/phi/ap_variadic_kernel.cc index 14b44434e0b57d..38de37722b5c0e 100644 --- a/paddle/ap/src/paddle/phi/ap_variadic_kernel.cc +++ b/paddle/ap/src/paddle/phi/ap_variadic_kernel.cc @@ -162,7 +162,7 @@ adt::Result VisitTensorIdxOrRange( const DoEachIdxT& DoEachIdx, const DoEachRangeT& DoEachRange) { using Ok = adt::Result; - for (int i = 0; i < list->size(); ++i) { + for (size_t i = 0; i < list->size(); ++i) { const auto& elt = list->at(i); ADT_RETURN_IF_ERR(elt.Match( [&](int64_t idx) -> Ok { @@ -214,7 +214,7 @@ adt::Result> 
MakeConstTensors( ADT_CHECK(start <= end); adt::List tensor_list; tensor_list->reserve(end - start); - for (int i = start; i < end; ++i) { + for (size_t i = start; i < end; ++i) { ADT_CHECK(i < xs.size()); const auto* x = xs.at(i); ADT_RETURN_IF_ERR(CollectTensor(&tensor_list, x)); @@ -258,7 +258,7 @@ adt::Result> MakeMutableTensors( ADT_CHECK(start <= end); adt::List tensor_list; tensor_list->reserve(end - start); - for (int i = start; i < end; ++i) { + for (size_t i = start; i < end; ++i) { ADT_CHECK(i < xs.size()); auto* x = xs.at(i); ADT_RETURN_IF_ERR(CollectTensor(&tensor_list, x)); From 60676fe7d71792a7bf620899b40963f85893b145 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 25 Feb 2025 10:54:49 +0800 Subject: [PATCH 12/43] Add cmake dependent. --- paddle/ap/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/ap/CMakeLists.txt b/paddle/ap/CMakeLists.txt index a07ebaed5507ed..5c84e1cc81dba2 100644 --- a/paddle/ap/CMakeLists.txt +++ b/paddle/ap/CMakeLists.txt @@ -41,7 +41,7 @@ cc_library( DEPS ${ap_phi_deps}) file(GLOB_RECURSE ap_pir_srcs "src/paddle/pir/*.cc") -set(ap_pir_deps axpr ap_drr) +set(ap_pir_deps axpr ap_drr absl) cc_library( ap_pir SRCS ${ap_pir_srcs} From 6b786d9d0105f0a4f5e1e3f98d3f42338c0993cb Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 25 Feb 2025 11:13:21 +0800 Subject: [PATCH 13/43] Fix some using statement without creating an alias. 
--- paddle/ap/include/axpr/error.h | 3 --- paddle/ap/include/axpr/instance_attrs.h | 2 +- paddle/ap/include/axpr/mutable_list.h | 2 +- paddle/ap/include/axpr/mutable_ordered_dict.h | 2 +- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/paddle/ap/include/axpr/error.h b/paddle/ap/include/axpr/error.h index 78c09da4d50172..4098c99cf979ed 100644 --- a/paddle/ap/include/axpr/error.h +++ b/paddle/ap/include/axpr/error.h @@ -30,7 +30,4 @@ using adt::errors::TypeError; using adt::errors::ValueError; using adt::errors::ZeroDivisionError; -template -using Result = adt::Result; - } // namespace ap::axpr diff --git a/paddle/ap/include/axpr/instance_attrs.h b/paddle/ap/include/axpr/instance_attrs.h index 15ca2ab5ac7340..dfdc7bac014e02 100644 --- a/paddle/ap/include/axpr/instance_attrs.h +++ b/paddle/ap/include/axpr/instance_attrs.h @@ -24,7 +24,7 @@ template struct InstanceAttrs : public memory::CirclableRef, AttrMapImpl> { using Base = memory::CirclableRef, AttrMapImpl>; - using Base::CirclableRef; + using typename Base::CirclableRef; }; } // namespace ap::axpr diff --git a/paddle/ap/include/axpr/mutable_list.h b/paddle/ap/include/axpr/mutable_list.h index de5b4fcfe7e131..c927b2b68b63c2 100644 --- a/paddle/ap/include/axpr/mutable_list.h +++ b/paddle/ap/include/axpr/mutable_list.h @@ -25,7 +25,7 @@ template struct MutableList : public memory::CirclableRef, std::vector> { using Base = memory::CirclableRef, std::vector>; - using Base::CirclableRef; + using typename Base::CirclableRef; }; template diff --git a/paddle/ap/include/axpr/mutable_ordered_dict.h b/paddle/ap/include/axpr/mutable_ordered_dict.h index 5ec74f9548e2ce..a9f1dd0e3b1cae 100644 --- a/paddle/ap/include/axpr/mutable_ordered_dict.h +++ b/paddle/ap/include/axpr/mutable_ordered_dict.h @@ -29,7 +29,7 @@ struct MutableOrderedDict MutableOrderedDictImpl> { using Base = memory::CirclableRef, MutableOrderedDictImpl>; - using Base::CirclableRef; + using typename Base::CirclableRef; }; template From 
f36dfe8d3ad9819f266c47800445d49d4b0b6f53 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 25 Feb 2025 13:45:15 +0800 Subject: [PATCH 14/43] Support experimental/type_traits for WIN32. --- paddle/ap/include/axpr/data_value_util.h | 4 ++-- paddle/ap/include/axpr/hash.h | 1 - paddle/ap/include/axpr/method_class.h | 30 ++++++++++++++++++------ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/paddle/ap/include/axpr/data_value_util.h b/paddle/ap/include/axpr/data_value_util.h index e64d0f217ecb46..6cb4bd922f323c 100644 --- a/paddle/ap/include/axpr/data_value_util.h +++ b/paddle/ap/include/axpr/data_value_util.h @@ -70,8 +70,8 @@ struct ArithmeticBinaryOpHelper { } else { return adt::errors::TypeError{ std::string() + "'%' only support intergral type, but receive: '" + - CppDataType{}.Name() + "' and '" + CppDataType{}.Name() + - "'."}; + CppDataType{}.Name() + "' and '" + + CppDataType{}.Name() + "'."}; } } }; diff --git a/paddle/ap/include/axpr/hash.h b/paddle/ap/include/axpr/hash.h index f40b3580d5f236..fba1df876e5af8 100644 --- a/paddle/ap/include/axpr/hash.h +++ b/paddle/ap/include/axpr/hash.h @@ -14,7 +14,6 @@ #pragma once -#include #include "paddle/ap/include/adt/adt.h" #include "paddle/ap/include/axpr/method_class.h" diff --git a/paddle/ap/include/axpr/method_class.h b/paddle/ap/include/axpr/method_class.h index a9fe458a087d47..804ba6acb01398 100644 --- a/paddle/ap/include/axpr/method_class.h +++ b/paddle/ap/include/axpr/method_class.h @@ -14,7 +14,9 @@ #pragma once +#ifndef _WIN32 #include +#endif #include #include #include @@ -80,6 +82,20 @@ struct MethodClassImpl; namespace detail { +#ifndef _WIN32 +template