[SOT] Non-break support for paddle.get_device #72004


Merged
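
In effect, this PR models Paddle Place objects inside SOT (the Symbolic Opcode Translator) with a new PlaceVariable, so that calls like place.get_device_id() and place.get_device_type(), and therefore paddle.get_device(), can be simulated instead of triggering a graph break. A minimal usage sketch of the resulting behavior, assuming a Paddle build with SOT enabled (check_no_breakgraph is the same helper the new test below uses):

    import paddle
    from paddle.jit.sot.psdb import check_no_breakgraph

    @check_no_breakgraph  # asserts tracing completes without a graph break
    def branch_on_device(x: paddle.Tensor):
        # With PlaceVariable, this device query is simulated inside the
        # traced function rather than forcing a fallback to eager mode.
        if "cpu" in paddle.get_device():
            return x + 1
        return x * 2

    x = paddle.to_tensor(1.0)
    print(paddle.jit.to_static(branch_on_device)(x))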
1 change: 0 additions & 1 deletion python/paddle/device/__init__.py
@@ -342,7 +342,6 @@ def get_device() -> str:
    elif isinstance(place, core.IPUPlace):
        num_devices = core.get_ipu_device_count()
        device = f"ipus:{{0-{num_devices - 1}}}"
-       device = f"ipus:{{0-{num_devices - 1}}}"
    elif isinstance(place, core.CustomPlace):
        device_id = place.get_device_id()
        device_type = place.get_device_type()
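
For reference, a small sketch of what get_device() returns at runtime; the strings in the comments follow the branch formats visible above, and the available devices depend on how Paddle was compiled:

    import paddle

    paddle.set_device("cpu")
    print(paddle.get_device())  # "cpu"
    # CUDA build:    paddle.set_device("gpu:0") -> "gpu:0"
    # IPU build:     the branch above yields e.g. "ipus:{0-3}"
    # custom device: "<device_type>:<device_id>", e.g. "npu:0"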
@@ -59,3 +59,11 @@ def tensor_dim(x):

def generator_send(x):
    pass


+def place_get_device_id():
+    pass
Copilot AI commented on Apr 1, 2025:
The function 'place_get_device_id' remains unimplemented. Add an appropriate implementation or provide a comment explaining why it is intentionally left as a stub.

Suggested change:
-    pass
+    import paddle
+    device = paddle.device.get_device()
+    device_id = int(device.split(':')[-1]) if ':' in device else 0
+    return device_id



+def place_get_device_type():
+    pass
Copilot AI commented on Mar 31, 2025:
The function 'place_get_device_id' is currently a stub (using 'pass'). If this function is meant to be implemented in this PR, consider raising a NotImplementedError or adding a TODO comment to clarify its intended behavior.

Suggested change:
-    pass
+    raise NotImplementedError("place_get_device_id function is not yet implemented.")

Copilot AI commented on Apr 1, 2025:
The function 'place_get_device_id' is currently unimplemented. Provide an appropriate implementation or a clear fallback to ensure correct behavior when this function is invoked.

Suggested change:
-    pass
+    if hasattr(x, 'device_id'):
+        return x.device_id
+    else:
+        raise FallbackError("Input does not have a device_id attribute.")

Comment on lines +68 to +69
Copilot AI commented on Apr 1, 2025:
The function 'place_get_device_type' remains unimplemented. Add an appropriate implementation or provide a comment explaining why it is intentionally left as a stub.

Suggested change:
-def place_get_device_type():
-    pass
+def place_get_device_type(place):
+    if place.is_gpu_place():
+        return "GPU"
+    elif place.is_cpu_place():
+        return "CPU"
+    else:
+        raise ValueError("Unsupported place type")

@@ -47,6 +47,8 @@
    operator_is_none,
    operator_is_not_none,
    operator_not_in,
+    place_get_device_id,
+    place_get_device_type,
    tensor_dim,
)
from .dispatcher import Dispatcher, optional
@@ -1586,3 +1588,15 @@ def dispatch_all(var: ContainerVariable | IterVariable):
        ufunc,
    ),
)
+
+# place
+Dispatcher.register(
+    place_get_device_id,
+    ("PlaceVariable",),
+    lambda var: var.get_device_id(),
+)
+Dispatcher.register(
+    place_get_device_type,
+    ("PlaceVariable",),
+    lambda var: var.get_device_type(),
+)
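
The pass-bodied functions from dispatch_functions.py are intentionally empty: the function objects serve only as dispatch keys, and the lambdas registered here do the actual work on the matched variable. A self-contained sketch of that pattern, using a hypothetical MiniDispatcher and a toy PlaceVariable rather than the real SOT classes:

    # Hypothetical illustration of the "stub function as dispatch key" pattern.
    class MiniDispatcher:
        handlers = {}

        @classmethod
        def register(cls, fn, param_types, handler):
            # the stub function object itself is part of the lookup key
            cls.handlers[(fn, param_types)] = handler

        @classmethod
        def dispatch(cls, fn, *args):
            key = (fn, tuple(type(a).__name__ for a in args))
            handler = cls.handlers.get(key)
            return handler(*args) if handler is not None else None

    def place_get_device_id():  # stub: never called, only used as a key
        pass

    class PlaceVariable:  # toy stand-in for the SOT variable class
        def __init__(self, device_id):
            self._device_id = device_id

        def get_device_id(self):
            return self._device_id

    MiniDispatcher.register(
        place_get_device_id, ("PlaceVariable",), lambda var: var.get_device_id()
    )
    print(MiniDispatcher.dispatch(place_get_device_id, PlaceVariable(0)))  # prints 0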
@@ -32,6 +32,7 @@
    NumpyVariable,
    ObjectVariable,
    ParameterVariable,
+    PlaceVariable,
    SliceVariable,
    SuperVariable,
    SymbolicVariable,
@@ -23,6 +23,7 @@
import numpy as np

import paddle
+from paddle._typing import unreached
from paddle.framework import core

from ....infer_meta import (
@@ -54,7 +55,11 @@
    InnerError,
    UnsupportedPaddleAPIBreak,
)
-from ..dispatch_functions import tensor_dim
+from ..dispatch_functions import (
+    place_get_device_id,
+    place_get_device_type,
+    tensor_dim,
+)
from ..guard import (
    FasterStringifiedExpression,
    StringifiedExpression,
@@ -1384,7 +1389,7 @@ def make_stringified_guard(self) -> None:

    @VariableFactory.register_from_value()
    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
-        if isinstance(value, (np.ndarray)):
+        if isinstance(value, np.ndarray):
            return NumpyArrayVariable(value, graph, tracker)
        return None

@@ -1439,7 +1444,7 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
class NumpyBoolVariable(NumpyNumberVariable):
    @VariableFactory.register_from_value()
    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
-        if isinstance(value, (np.bool_)):
+        if isinstance(value, np.bool_):
            return NumpyBoolVariable(value, graph, tracker)
        return None

@@ -1469,6 +1474,54 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
        ]


+class PlaceVariable(ObjectVariable):
+    def __init__(self, obj, graph, tracker):
+        super().__init__(obj, graph, tracker)
+
+    def getattr(self, name: str, default=None):
+        if default is not None:
+            raise FallbackError(
+                "default argument for getattr is not implemented"
+            )
+        if name not in ["get_device_id", "get_device_type"]:
+            return super().getattr(name, default)
+        from .callable import BuiltinVariable
+
+        if name == "get_device_id":
+            return BuiltinVariable(
+                place_get_device_id, self.graph, DanglingTracker()
Copilot AI commented on Apr 1, 2025:
DanglingTracker is used here without an explicit import. Please ensure that DanglingTracker is imported from its appropriate module to avoid a NameError.
+            ).bind_dangling_fn(self, name)
+        elif name == "get_device_type":
+            return BuiltinVariable(
+                place_get_device_type, self.graph, DanglingTracker()
+            ).bind_dangling_fn(self, name)
+        unreached()

+    def get_device_id(self):
+        return VariableFactory.from_value(
+            self.value.get_device_id(), self.graph, DummyTracker([self])
+        )
+
+    def get_device_type(self):
+        return VariableFactory.from_value(
+            self.value.get_device_type(), self.graph, DummyTracker([self])
+        )
+
+    @VariableFactory.register_from_value()
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if paddle.is_compiled_with_cuda() and isinstance(
+            value, (paddle.CUDAPlace, paddle.CUDAPinnedPlace)
+        ):
+            return PlaceVariable(value, graph, tracker)
+        if paddle.is_compiled_with_xpu() and isinstance(
+            value, (paddle.XPUPlace, paddle.XPUPinnedPlace)
+        ):
+            return PlaceVariable(value, graph, tracker)
+        if isinstance(value, paddle.CustomPlace):
+            return PlaceVariable(value, graph, tracker)
+        return None


class NullVariable(VariableBase):
    """
    NullVariable is a subclass of VariableBase used to represent a placeholder variable that has no value or reference associated with it.
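
Putting the pieces together: when traced code calls a method on a Place, PlaceVariable.getattr returns a BuiltinVariable that wraps the matching stub and binds it to the place; the Dispatcher then routes that stub to the lambda registered earlier, which invokes get_device_id()/get_device_type() on the underlying value and wraps the result via VariableFactory.from_value with a DummyTracker recording the place it came from. For context, a sketch of how these Place methods are consumed by the CustomPlace branch of get_device() shown at the top of this diff (assumes a custom-device plugin named "npu" is installed; the plugin name is illustrative):

    import paddle

    place = paddle.CustomPlace("npu", 0)  # requires an installed "npu" plugin
    device = f"{place.get_device_type()}:{place.get_device_id()}"  # "npu:0"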
104 changes: 104 additions & 0 deletions test/sot/test_sot_place.py
@@ -0,0 +1,104 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from contextlib import contextmanager

from test_case_base import (
    TestCaseBase,
    test_instruction_translator_cache_context,
)

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph


@contextmanager
def device_guard(place: str):
    original_place = paddle.get_device()
    try:
        paddle.set_device(place)
        yield
    finally:
        paddle.set_device(original_place)


@check_no_breakgraph
def run_diff_logic_by_check_expected_place(x: paddle.Tensor):
    expected_place_str = paddle.get_device()
    if "cpu" in expected_place_str:
        return x + 1
    elif "gpu" in expected_place_str:
        return x + 2
    elif "xpu" in expected_place_str:
        return x + 3
    elif "npu" in expected_place_str:
        return x + 4
    return x


class TestCheckExpectedPlace(TestCaseBase):
    def test_check_cpu(self):
        x = paddle.to_tensor(0.0)
        with device_guard("cpu"):
            self.assert_results(run_diff_logic_by_check_expected_place, x.cpu())

    @unittest.skipUnless(
        paddle.is_compiled_with_cuda(),
        "This test case needs to be compiled with CUDA",
    )
    def test_check_gpu(self):
        x = paddle.to_tensor(0.0)
        with device_guard("gpu"):
            self.assert_results(
                run_diff_logic_by_check_expected_place, x.cuda()
            )

    @unittest.skipUnless(
        paddle.is_compiled_with_xpu(),
        "This test case needs to be compiled with XPU",
    )
    def test_check_xpu(self):
        x = paddle.to_tensor(0.0)
        with device_guard("xpu"):
            self.assert_results(
                run_diff_logic_by_check_expected_place, x.to("xpu")
            )


class TestExpectedPlaceGuard(TestCaseBase):
    @unittest.skipUnless(
        paddle.is_compiled_with_cuda(),
        "This test case needs to be compiled with cuda",
    )
    def test_expected_place_guard(self):
        x = paddle.to_tensor(0.0)
        with test_instruction_translator_cache_context() as ctx:
            self.assertEqual(ctx.translate_count, 0)
            with device_guard("cpu"):
                self.assert_results(
                    run_diff_logic_by_check_expected_place, x.cpu()
                )
            self.assertEqual(ctx.translate_count, 1)
            with device_guard("gpu"):
                self.assert_results(
                    run_diff_logic_by_check_expected_place, x.cuda()
                )
            self.assertEqual(ctx.translate_count, 2)


if __name__ == "__main__":
    unittest.main()