-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[SOT] Non-break support for paddle.get_device
#72004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e9de6eb
ef902db
665b40d
c789a7d
152fe82
3c9c0c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -59,3 +59,11 @@ def tensor_dim(x): | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def generator_send(x): | ||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def place_get_device_id(): | ||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def place_get_device_type(): | ||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function 'place_get_device_id' is currently a stub (using 'pass'). If this function is meant to be implemented in this PR, consider raising a NotImplementedError or adding a TODO comment to clarify its intended behavior.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function 'place_get_device_id' is currently unimplemented. Provide an appropriate implementation or a clear fallback to ensure correct behavior when this function is invoked.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback
Comment on lines
+68
to
+69
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function 'place_get_device_type' remains unimplemented. Add an appropriate implementation or provide a comment explaining why it is intentionally left as a stub.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
import numpy as np | ||
|
||
import paddle | ||
from paddle._typing import unreached | ||
from paddle.framework import core | ||
|
||
from ....infer_meta import ( | ||
|
@@ -54,7 +55,11 @@ | |
InnerError, | ||
UnsupportedPaddleAPIBreak, | ||
) | ||
from ..dispatch_functions import tensor_dim | ||
from ..dispatch_functions import ( | ||
place_get_device_id, | ||
place_get_device_type, | ||
tensor_dim, | ||
) | ||
from ..guard import ( | ||
FasterStringifiedExpression, | ||
StringifiedExpression, | ||
|
@@ -1384,7 +1389,7 @@ def make_stringified_guard(self) -> None: | |
|
||
@VariableFactory.register_from_value() | ||
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): | ||
if isinstance(value, (np.ndarray)): | ||
if isinstance(value, np.ndarray): | ||
return NumpyArrayVariable(value, graph, tracker) | ||
return None | ||
|
||
|
@@ -1439,7 +1444,7 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): | |
class NumpyBoolVariable(NumpyNumberVariable): | ||
@VariableFactory.register_from_value() | ||
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): | ||
if isinstance(value, (np.bool_)): | ||
if isinstance(value, np.bool_): | ||
return NumpyBoolVariable(value, graph, tracker) | ||
return None | ||
|
||
|
@@ -1469,6 +1474,54 @@ def make_stringified_guard(self) -> list[StringifiedExpression]: | |
] | ||
|
||
|
||
class PlaceVariable(ObjectVariable): | ||
def __init__(self, obj, graph, tracker): | ||
super().__init__(obj, graph, tracker) | ||
|
||
def getattr(self, name: str, default=None): | ||
if default is not None: | ||
raise FallbackError( | ||
"default argument for getattr is not implemented" | ||
) | ||
if name not in ["get_device_id", "get_device_type"]: | ||
return super().getattr(name, default) | ||
from .callable import BuiltinVariable | ||
|
||
if name == "get_device_id": | ||
return BuiltinVariable( | ||
place_get_device_id, self.graph, DanglingTracker() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DanglingTracker is used here without an explicit import. Please ensure that DanglingTracker is imported from its appropriate module to avoid a NameError. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
).bind_dangling_fn(self, name) | ||
elif name == "get_device_type": | ||
return BuiltinVariable( | ||
place_get_device_type, self.graph, DanglingTracker() | ||
).bind_dangling_fn(self, name) | ||
unreached() | ||
|
||
def get_device_id(self): | ||
return VariableFactory.from_value( | ||
self.value.get_device_id(), self.graph, DummyTracker([self]) | ||
) | ||
|
||
def get_device_type(self): | ||
return VariableFactory.from_value( | ||
self.value.get_device_type(), self.graph, DummyTracker([self]) | ||
) | ||
|
||
@VariableFactory.register_from_value() | ||
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): | ||
if paddle.is_compiled_with_cuda() and isinstance( | ||
value, (paddle.CUDAPlace, paddle.CUDAPinnedPlace) | ||
): | ||
return PlaceVariable(value, graph, tracker) | ||
if paddle.is_compiled_with_xpu() and isinstance( | ||
value, (paddle.XPUPlace, paddle.XPUPinnedPlace) | ||
): | ||
return PlaceVariable(value, graph, tracker) | ||
if isinstance(value, paddle.CustomPlace): | ||
return PlaceVariable(value, graph, tracker) | ||
return None | ||
|
||
|
||
class NullVariable(VariableBase): | ||
""" | ||
NullVariable is a subclass of VariableBase used to represent a placeholder variable that has no value or reference associated with it. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
from contextlib import contextmanager | ||
|
||
from test_case_base import ( | ||
TestCaseBase, | ||
test_instruction_translator_cache_context, | ||
) | ||
|
||
import paddle | ||
from paddle.jit.sot.psdb import check_no_breakgraph | ||
|
||
|
||
@contextmanager | ||
def device_guard(place: str): | ||
original_place = paddle.get_device() | ||
try: | ||
paddle.set_device(place) | ||
yield | ||
finally: | ||
paddle.set_device(original_place) | ||
|
||
|
||
@check_no_breakgraph | ||
def run_diff_logic_by_check_expected_place(x: paddle.Tensor): | ||
expected_place_str = paddle.get_device() | ||
if "cpu" in expected_place_str: | ||
return x + 1 | ||
elif "gpu" in expected_place_str: | ||
return x + 2 | ||
elif "xpu" in expected_place_str: | ||
return x + 3 | ||
elif "npu" in expected_place_str: | ||
return x + 4 | ||
return x | ||
|
||
|
||
class TestCheckExpectedPlace(TestCaseBase): | ||
def test_check_cpu(self): | ||
x = paddle.to_tensor(0.0) | ||
with device_guard("cpu"): | ||
self.assert_results(run_diff_logic_by_check_expected_place, x.cpu()) | ||
|
||
@unittest.skipUnless( | ||
paddle.is_compiled_with_cuda(), | ||
"This test case needs to be compiled with CUDA", | ||
) | ||
def test_check_gpu(self): | ||
x = paddle.to_tensor(0.0) | ||
with device_guard("gpu"): | ||
self.assert_results( | ||
run_diff_logic_by_check_expected_place, x.cuda() | ||
) | ||
|
||
@unittest.skipUnless( | ||
paddle.is_compiled_with_xpu(), | ||
"This test case needs to be compiled with XPU", | ||
) | ||
def test_check_xpu(self): | ||
x = paddle.to_tensor(0.0) | ||
with device_guard("xpu"): | ||
self.assert_results( | ||
run_diff_logic_by_check_expected_place, x.to("xpu") | ||
) | ||
|
||
|
||
class TestExpectedPlaceGuard(TestCaseBase): | ||
@unittest.skipUnless( | ||
paddle.is_compiled_with_cuda(), | ||
"This test case needs to be compiled with cuda", | ||
) | ||
def test_expected_place_guard(self): | ||
x = paddle.to_tensor(0.0) | ||
with test_instruction_translator_cache_context() as ctx: | ||
self.assertEqual(ctx.translate_count, 0) | ||
with device_guard("cpu"): | ||
self.assert_results( | ||
run_diff_logic_by_check_expected_place, x.cpu() | ||
) | ||
self.assertEqual(ctx.translate_count, 1) | ||
with device_guard("gpu"): | ||
self.assert_results( | ||
run_diff_logic_by_check_expected_place, x.cuda() | ||
) | ||
self.assertEqual(ctx.translate_count, 2) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function 'place_get_device_id' remains unimplemented. Add an appropriate implementation or provide a comment explaining why it is intentionally left as a stub.
Copilot uses AI. Check for mistakes.