diff --git a/hmdriver2/_client.py b/hmdriver2/_client.py index 74a583e..8cb2385 100644 --- a/hmdriver2/_client.py +++ b/hmdriver2/_client.py @@ -1,231 +1,405 @@ # -*- coding: utf-8 -*- - -import socket +import os import json +import socket +import struct import time -import os -import typing -import subprocess import hashlib from datetime import datetime from functools import cached_property - -from . import logger -from .hdc import HdcWrapper -from .proto import HypiumResponse, DriverData from .exception import InvokeHypiumError, InvokeCaptures +from .hdc import HdcWrapper +from .proto import HypiumResponse + + +class SocketConfig: + PORT = 8012 + TIMEOUT = 5 + BUFFER_SIZE = 8192 + + +class MessageProtocol: + HEADER = b'_uitestkit_rpc_message_head_' # 消息头 + TAILER = b'_uitestkit_rpc_message_tail_' # 消息尾 + SESSION_ID_LENGTH = 4 + LENGTH_FIELD_LENGTH = 4 + HEADER_LENGTH = len(HEADER) + TAILER_LENGTH = len(TAILER) + FULL_HEADER_LENGTH = HEADER_LENGTH + SESSION_ID_LENGTH + LENGTH_FIELD_LENGTH + +LOCAL_HOST = "127.0.0.1" +API_MODULE = "com.ohos.devicetest.hypiumApiHelper" +API_METHOD_HYPIUM = "callHypiumApi" +API_METHOD_CAPTURES = "Captures" +DEFAULT_THIS = "Driver#0" -UITEST_SERVICE_PORT = 8012 -SOCKET_TIMEOUT = 20 + +class MessageCircularBuffer: + """环形缓冲区实现""" + + def __init__(self, initial_capacity: int = 65536, max_capacity: int = 10 * 1024 * 1024): + self.initial_capacity = initial_capacity + self.max_capacity = max_capacity + self.buffer = bytearray(initial_capacity) + self.capacity = initial_capacity + self.head = 0 + self.tail = 0 + self.size = 0 + + def _ensure_capacity(self, required: int) -> None: + if self.size + required <= self.capacity: + return + + new_capacity = self.capacity + while new_capacity < self.size + required: + new_capacity = min(new_capacity * 2, self.max_capacity) + + if new_capacity > self.capacity: + self._expand(new_capacity) + + def _expand(self, new_capacity: int) -> None: + if new_capacity <= self.capacity: + return + + new_buffer = bytearray(new_capacity) + if self.size > 0: + if self.tail < self.head: + new_buffer[0:self.size] = self.buffer[self.tail:self.head] + else: + first_part = self.capacity - self.tail + new_buffer[0:first_part] = self.buffer[self.tail:self.capacity] + new_buffer[first_part:self.size] = self.buffer[0:self.head] + + self.buffer = new_buffer + self.capacity = new_capacity + self.tail = 0 + self.head = self.size + + def write(self, data: bytes) -> int: + data_len = len(data) + if data_len == 0: + return 0 + + self._ensure_capacity(data_len) + write_len = min(data_len, self.capacity - self.size) + + if self.head + write_len <= self.capacity: + self.buffer[self.head:self.head + write_len] = data[:write_len] + self.head = (self.head + write_len) % self.capacity + else: + first_part = self.capacity - self.head + self.buffer[self.head:self.capacity] = data[:first_part] + second_part = write_len - first_part + self.buffer[0:second_part] = data[first_part:first_part + second_part] + self.head = second_part + + self.size += write_len + return write_len + + def read(self, length: int) -> bytearray: + if length <= 0 or self.size == 0: + return bytearray() + + read_len = min(length, self.size) + result = bytearray(read_len) + + if self.tail + read_len <= self.capacity: + result[:] = self.buffer[self.tail:self.tail + read_len] + self.tail = (self.tail + read_len) % self.capacity + else: + first_part = self.capacity - self.tail + result[0:first_part] = self.buffer[self.tail:self.capacity] + second_part = read_len - first_part + result[first_part:read_len] = self.buffer[0:second_part] + self.tail = second_part + + self.size -= read_len + return result + + def peek(self, length: int) -> bytearray: + if length <= 0 or self.size == 0: + return bytearray() + + read_len = min(length, self.size) + result = bytearray(read_len) + + if self.tail + read_len <= self.capacity: + result[:] = self.buffer[self.tail:self.tail + read_len] + else: + first_part = self.capacity - self.tail + result[0:first_part] = self.buffer[self.tail:self.capacity] + second_part = read_len - first_part + result[first_part:read_len] = self.buffer[0:second_part] + + return result + + def discard(self, length: int) -> int: + if length <= 0 or self.size == 0: + return 0 + + discard_len = min(length, self.size) + self.tail = (self.tail + discard_len) % self.capacity + self.size -= discard_len + return discard_len + + def find(self, pattern: bytes, start: int = 0) -> int: + if not pattern or self.size == 0 or start >= self.size: + return -1 + + pattern_len = len(pattern) + if pattern_len == 0: + return start + + for i in range(start, self.size - pattern_len + 1): + match = True + for j in range(pattern_len): + pos = (self.tail + i + j) % self.capacity + if self.buffer[pos] != pattern[j]: + match = False + break + if match: + return i + + return -1 + + def clear(self) -> None: + self.buffer = bytearray(self.initial_capacity) + self.capacity = self.initial_capacity + self.head = 0 + self.tail = 0 + self.size = 0 class HmClient: - """harmony uitest client""" + """Harmony OS 设备通信客户端""" + def __init__(self, serial: str): self.hdc = HdcWrapper(serial) - self.sock = None + self.sock: socket.socket | None = None + self.recv_buffer = MessageCircularBuffer() @cached_property - def local_port(self): - fports = self.hdc.list_fport() - logger.debug(fports) if fports else None - - return self.hdc.forward_port(UITEST_SERVICE_PORT) + def local_port(self) -> int: + return self.hdc.forward_port(SocketConfig.PORT) - def _rm_local_port(self): - logger.debug("rm fport local port") - self.hdc.rm_forward(self.local_port, UITEST_SERVICE_PORT) + def _rm_local_port(self) -> None: + self.hdc.rm_forward(self.local_port, SocketConfig.PORT) - def _connect_sock(self): - """Create socket and connect to the uiTEST server.""" + def _connect_sock(self) -> None: self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(SOCKET_TIMEOUT) - self.sock.connect((("127.0.0.1", self.local_port))) - - def _send_msg(self, msg: typing.Dict): - """Send an message to the server. - Example: - { - "module": "com.ohos.devicetest.hypiumApiHelper", - "method": "callHypiumApi", - "params": { - "api": "Driver.create", - "this": null, - "args": [], - "message_type": "hypium" - }, - "request_id": "20240815161352267072", - "client": "127.0.0.1" - } - """ - msg = json.dumps(msg, ensure_ascii=False, separators=(',', ':')) - logger.debug(f"sendMsg: {msg}") - self.sock.sendall(msg.encode('utf-8') + b'\n') - - def _recv_msg(self, buff_size: int = 4096, decode=False, print=True) -> typing.Union[bytearray, str]: - full_msg = bytearray() - try: - # FIXME - relay = self.sock.recv(buff_size) - if decode: - relay = relay.decode() - if print: - logger.debug(f"recvMsg: {relay}") - full_msg = relay - - except (socket.timeout, UnicodeDecodeError) as e: - logger.warning(e) - if decode: - full_msg = "" - - return full_msg - - def invoke(self, api: str, this: str = "Driver#0", args: typing.List = []) -> HypiumResponse: - """ - Hypium invokes given API method with the specified arguments and handles exceptions. - - Args: - api (str): The name of the API method to invoke. - args (List, optional): A list of arguments to pass to the API method. Default is an empty list. - - Returns: - HypiumResponse: The response from the API call. - - Raises: - InvokeHypiumError: If the API call returns an exception in the response. - """ - - request_id = datetime.now().strftime("%Y%m%d%H%M%S%f") - params = { - "api": api, - "this": this, - "args": args, - "message_type": "hypium" - } + self.sock.settimeout(SocketConfig.TIMEOUT) + self.sock.connect((LOCAL_HOST, self.local_port)) + self.recv_buffer.clear() - msg = { - "module": "com.ohos.devicetest.hypiumApiHelper", - "method": "callHypiumApi", - "params": params, - "request_id": request_id - } + def _send_msg(self, msg: dict) -> None: + msg_str = json.dumps(msg, ensure_ascii=False, separators=(',', ':')) + msg_bytes = msg_str.encode('utf-8') - self._send_msg(msg) - raw_data = self._recv_msg(decode=True) - data = HypiumResponse(**(json.loads(raw_data))) - if data.exception: - raise InvokeHypiumError(data.exception) - return data + session_id = self._generate_session_id(msg_str) + header = ( + MessageProtocol.HEADER + + struct.pack('>I', session_id) + + struct.pack('>I', len(msg_bytes)) + ) - def invoke_captures(self, api: str, args: typing.List = []) -> HypiumResponse: - request_id = datetime.now().strftime("%Y%m%d%H%M%S%f") - params = { - "api": api, - "args": args - } + if self.sock is None: + raise ConnectionError("Socket 未连接") + + self.sock.sendall(header + msg_bytes + MessageProtocol.TAILER) - msg = { - "module": "com.ohos.devicetest.hypiumApiHelper", - "method": "Captures", + @staticmethod + def _generate_session_id(message: str) -> int: + combined = str(int(time.time() * 1000)) + message + os.urandom(4).hex() + return struct.unpack('>I', hashlib.sha256(combined.encode()).digest()[:4])[0] | 0x80000000 + + def _recv_msg(self, decode: bool = False) -> bytearray | str: + try: + while True: + result = self._try_parse_message() + if result is not None: + return result.decode('utf-8') if decode else result + + if self.sock is None: + raise ConnectionError("Socket 未连接") + + chunk = self.sock.recv(SocketConfig.BUFFER_SIZE) + if not chunk: + raise ConnectionError("连接已关闭") + + self.recv_buffer.write(chunk) + + except (socket.timeout, ValueError, json.JSONDecodeError) as e: + print(f"接收消息时出错: {e}") + return bytearray() if not decode else "" + + def _try_parse_message(self) -> bytearray | None: + """尝试从缓冲区中解析一个完整消息""" + MAX_MESSAGE_SIZE = 5 * 1024 * 1024 + + while True: + header_pos = self.recv_buffer.find(MessageProtocol.HEADER) + if header_pos == -1: + # 没有找到消息头,保留数据等待更多数据 + # 但需要防止缓冲区被填满,如果数据量过大可能不是合法数据 + if self.recv_buffer.size > MessageProtocol.FULL_HEADER_LENGTH * 3: + keep_size = MessageProtocol.FULL_HEADER_LENGTH * 2 + # 丢弃部分数据,避免缓冲区溢出 + discard_len = self.recv_buffer.size - keep_size + if discard_len > 0: + self.recv_buffer.discard(discard_len) + return None + # 丢弃头部之前的数据 + if header_pos > 0: + self.recv_buffer.discard(header_pos) + # 检查是否有完整的消息头 + if self.recv_buffer.size < MessageProtocol.FULL_HEADER_LENGTH: + return None + # 提取消息头 + header_data = self.recv_buffer.peek(MessageProtocol.FULL_HEADER_LENGTH) + # 验证消息头格式 + if header_data[:MessageProtocol.HEADER_LENGTH] != MessageProtocol.HEADER: + # 无效的消息头,丢弃一个字节后重试 + self.recv_buffer.discard(1) + continue + + length_pos = MessageProtocol.HEADER_LENGTH + MessageProtocol.SESSION_ID_LENGTH + msg_length = struct.unpack('>I', header_data[length_pos:length_pos + 4])[0] + + if msg_length > MAX_MESSAGE_SIZE or msg_length == 0: + print(f"无效的消息长度: {msg_length},丢弃数据") + self.recv_buffer.discard(MessageProtocol.FULL_HEADER_LENGTH) + continue + # 计算完整消息长度 + full_msg_length = MessageProtocol.FULL_HEADER_LENGTH + msg_length + MessageProtocol.TAILER_LENGTH + # 检查是否有完整消息 + if self.recv_buffer.size < full_msg_length: + return None + # 提取完整消息 + full_msg = self.recv_buffer.read(full_msg_length) + # 验证消息尾 + tail_start = MessageProtocol.FULL_HEADER_LENGTH + msg_length + tailer = full_msg[tail_start:tail_start + MessageProtocol.TAILER_LENGTH] + + if tailer != MessageProtocol.TAILER: + print("消息尾不匹配,可能数据损坏") + next_header_pos = self.recv_buffer.find(MessageProtocol.HEADER) + if next_header_pos != -1: + self.recv_buffer.discard(next_header_pos) + continue + + return full_msg[MessageProtocol.FULL_HEADER_LENGTH:tail_start] + + @staticmethod + def _build_request(method: str, api: str, args: list, this=None, message_type=None) -> dict: + params = {"api": api, "args": args} + if this is not None: + params["this"] = this + + if message_type is not None: + params["message_type"] = message_type + + return { + "module": API_MODULE, + "method": method, "params": params, - "request_id": request_id + "request_id": datetime.now().strftime("%Y%m%d%H%M%S%f") } + def _invoke_common(self, method: str, api: str, args: list | None, this: str | None, message_type, + exception_class) -> HypiumResponse: + if args is None: + args = [] + + msg = self._build_request(method, api, args, this, message_type) self._send_msg(msg) + raw_data = self._recv_msg(decode=True) - data = HypiumResponse(**(json.loads(raw_data))) + if not raw_data: + raise exception_class("接收响应失败") + + try: + data = HypiumResponse(**(json.loads(raw_data))) + except json.JSONDecodeError as e: + raise exception_class(f"解析响应失败: {e}") + if data.exception: - raise InvokeCaptures(data.exception) + raise exception_class(data.exception) return data - def start(self): - logger.info("Start HmClient connection") - self._init_so_resource() - self._restart_uitest_service() + def invoke(self, api: str, this: str | None = DEFAULT_THIS, args: list | None = None) -> HypiumResponse: + return self._invoke_common(API_METHOD_HYPIUM, api, args, this, "hypium", InvokeHypiumError) - self._connect_sock() + def invoke_captures(self, api: str, args: list | None = None) -> HypiumResponse: + return self._invoke_common(API_METHOD_CAPTURES, api, args, None, None, InvokeCaptures) - self._create_hdriver() + def start(self) -> None: + _UITestService(self.hdc).init() + self._connect_sock() - def release(self): - logger.info(f"Release {self.__class__.__name__} connection") + def release(self) -> None: try: if self.sock: self.sock.close() self.sock = None - self._rm_local_port() - + self.recv_buffer.clear() except Exception as e: - logger.error(f"An error occurred: {e}") - - def _create_hdriver(self) -> DriverData: - logger.debug("create uitest driver") - resp: HypiumResponse = self.invoke("Driver.create") # {"result":"Driver#0"} - hdriver: DriverData = DriverData(resp.result) - return hdriver - - def _init_so_resource(self): - "Initialize the agent.so resource on the device." - - logger.debug("init the agent.so resource on the device.") - - def __get_so_local_path() -> str: - current_path = os.path.realpath(__file__) - return os.path.join(os.path.dirname(current_path), "assets", "uitest_agent_v1.1.0.so") - - def __check_device_so_file_exists() -> bool: - """Check if the agent.so file exists on the device.""" - command = "[ -f /data/local/tmp/agent.so ] && echo 'so exists' || echo 'so not exists'" - result = self.hdc.shell(command).output.strip() - return "so exists" in result - - def __get_remote_md5sum() -> str: - """Get the MD5 checksum of the file on the device.""" - command = "md5sum /data/local/tmp/agent.so" - data = self.hdc.shell(command).output.strip() - return data.split()[0] - - def __get_local_md5sum(f: str) -> str: - """Calculate the MD5 checksum of a local file.""" - hash_md5 = hashlib.md5() - with open(f, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - hash_md5.update(chunk) - return hash_md5.hexdigest() - - local_path = __get_so_local_path() - remote_path = "/data/local/tmp/agent.so" - - if __check_device_so_file_exists() and __get_local_md5sum(local_path) == __get_remote_md5sum(): - return + print(f"释放资源时出错: {e}") + + +class _UITestService: + """UITest 服务管理类""" + + def __init__(self, hdc: HdcWrapper): + self.hdc = hdc + self._remote_agent_path = "/data/local/tmp/agent.so" + + def init(self) -> None: + local_path = self._get_local_agent_path() + self._kill_uitest_service() + self._setup_device_agent(local_path, self._remote_agent_path) + self._start_uitest_daemon() + time.sleep(0.5) + + def _get_local_agent_path(self) -> str: + target_agent = "uitest_agent_v1.1.7.so" + return os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", target_agent) + + def _get_remote_md5sum(self, file_path: str) -> str | None: + output = self.hdc.shell(f"md5sum {file_path}").output.strip() + return output.split()[0] if output else None + + def _get_local_md5sum(self, file_path: str) -> str: + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + def _is_remote_file_exists(self, file_path: str) -> bool: + result = self.hdc.shell(f"[ -f {file_path} ] && echo 'exists' || echo 'not exists'").output.strip() + return "exists" in result + + def _setup_device_agent(self, local_path: str, remote_path: str) -> None: + if self._is_remote_file_exists(remote_path): + local_md5 = self._get_local_md5sum(local_path) + remote_md5 = self._get_remote_md5sum(remote_path) + if local_md5 == remote_md5: + self.hdc.shell(f"chmod +x {remote_path}") + return + self.hdc.shell(f"rm {remote_path}") + self.hdc.send_file(local_path, remote_path) self.hdc.shell(f"chmod +x {remote_path}") - def _restart_uitest_service(self): - """ - Restart the UITest daemon. + def _get_uitest_pid(self) -> list: + result = self.hdc.shell("pgrep -f 'uitest start-daemon singleness'").output.strip() + return result.splitlines() if result else [] - Note: 'hdc shell aa test' will also start a uitest daemon. - $ hdc shell ps -ef |grep uitest - shell 44306 1 25 11:03:37 ? 00:00:16 uitest start-daemon singleness - shell 44416 1 2 11:03:42 ? 00:00:01 uitest start-daemon com.hmtest.uitest@4x9@1" - """ - try: - result = self.hdc.shell("ps -ef").output.strip() - lines = result.splitlines() - filtered_lines = [line for line in lines if 'uitest' in line and 'singleness' in line] - - for line in filtered_lines: - if 'uitest start-daemon singleness' in line: - parts = line.split() - pid = parts[1] - self.hdc.shell(f"kill -9 {pid}") - logger.debug(f"Killed uitest process with PID {pid}") - - self.hdc.shell("uitest start-daemon singleness") - time.sleep(.5) - - except subprocess.CalledProcessError as e: - logger.error(f"An error occurred: {e}") + def _kill_uitest_service(self) -> None: + for pid in self._get_uitest_pid(): + self.hdc.shell(f"kill -9 {pid}") + + def _start_uitest_daemon(self) -> None: + self.hdc.shell("uitest start-daemon singleness") diff --git a/hmdriver2/_gesture.py b/hmdriver2/_gesture.py index d81d6b5..ddd2c51 100644 --- a/hmdriver2/_gesture.py +++ b/hmdriver2/_gesture.py @@ -2,7 +2,6 @@ import math from typing import List, Union -from . import logger from .utils import delay from .driver import Driver from .proto import HypiumResponse, Point @@ -95,7 +94,6 @@ def action(self): """ Execute the gesture action. """ - logger.info(f">>>Gesture steps: {self.steps}") total_points = self._calculate_total_points() pointer_matrix = self._create_pointer_matrix(total_points) @@ -140,7 +138,7 @@ def _add_step(self, x: int, y: int, step_type: str, interval: float): step_type (str): Type of step ("start", "move", or "pause"). interval (float): Interval duration in seconds. """ - point: Point = self.d._to_abs_pos(x, y) + point: Point = self.d.to_abs_pos(x, y) step = GestureStep(point.to_tuple(), step_type, interval) self.steps.append(step) diff --git a/hmdriver2/_screenrecord.py b/hmdriver2/_screenrecord.py index 4cf56ab..b8f2170 100644 --- a/hmdriver2/_screenrecord.py +++ b/hmdriver2/_screenrecord.py @@ -1,36 +1,70 @@ # -*- coding: utf-8 -*- -import typing +import queue import threading -import numpy as np -from queue import Queue from datetime import datetime +from typing import List, Optional, Any import cv2 +import numpy as np from . import logger from ._client import HmClient from .driver import Driver from .exception import ScreenRecordError +# 常量定义 +JPEG_START_FLAG = b'\xff\xd8' # JPEG 图像开始标记 +JPEG_END_FLAG = b'\xff\xd9' # JPEG 图像结束标记 +VIDEO_FPS = 10 # 视频帧率 +VIDEO_CODEC = 'mp4v' # 视频编码格式 +QUEUE_TIMEOUT = 0.1 # 队列超时时间(秒) + class RecordClient(HmClient): + """ + 屏幕录制客户端 + + 继承自 HmClient,提供设备屏幕录制功能 + """ + def __init__(self, serial: str, d: Driver): + """ + 初始化屏幕录制客户端 + + Args: + serial: 设备序列号 + d: Driver 实例 + """ super().__init__(serial) self.d = d - self.video_path = None - self.jpeg_queue = Queue() - self.threads: typing.List[threading.Thread] = [] - self.stop_event = threading.Event() + self.video_path: Optional[str] = None + self.jpeg_queue: queue.Queue = queue.Queue() + self.threads: List[threading.Thread] = [] + self.stop_event: threading.Event = threading.Event() def __enter__(self): + """上下文管理器入口""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器退出时停止录制""" self.stop() - def _send_msg(self, api: str, args: list): + def _send_msg(self, api: str, args: Optional[List[Any]] = None): + """ + 发送消息到设备 + + 重写父类方法,使用 Captures API + + Args: + api: API 名称 + args: API 参数列表,默认为空列表 + """ + if args is None: + args = [] + _msg = { "module": "com.ohos.devicetest.hypiumApiHelper", "method": "Captures", @@ -43,16 +77,32 @@ def _send_msg(self, api: str, args: list): super()._send_msg(_msg) def start(self, video_path: str): - logger.info("Start RecordClient connection") - + """ + 开始屏幕录制 + + Args: + video_path: 视频保存路径 + + Returns: + RecordClient: 当前实例,支持链式调用 + + Raises: + ScreenRecordError: 启动屏幕录制失败时抛出 + """ + logger.info("开始屏幕录制") + + # 连接设备 self._connect_sock() self.video_path = video_path + # 发送开始录制命令 self._send_msg("startCaptureScreen", []) - reply: str = self._recv_msg(1024, decode=True, print=False) + # 检查响应 + reply: str = self._recv_msg(decode=True, print=False) if "true" in reply: + # 创建并启动工作线程 record_th = threading.Thread(target=self._record_worker) writer_th = threading.Thread(target=self._video_writer) record_th.daemon = True @@ -61,69 +111,100 @@ def start(self, video_path: str): writer_th.start() self.threads.extend([record_th, writer_th]) else: - raise ScreenRecordError("Failed to start device screen capture.") + raise ScreenRecordError("启动设备屏幕录制失败") return self def _record_worker(self): - """Capture screen frames and save current frames.""" - - # JPEG start and end markers. - start_flag = b'\xff\xd8' - end_flag = b'\xff\xd9' + """ + 屏幕帧捕获工作线程 + + 捕获屏幕帧并保存当前帧 + """ buffer = bytearray() while not self.stop_event.is_set(): try: - buffer += self._recv_msg(4096 * 1024, decode=False, print=False) + buffer += self._recv_msg(decode=False, print=False) except Exception as e: - print(f"Error receiving data: {e}") + logger.error(f"接收数据时出错: {e}") break - start_idx = buffer.find(start_flag) - end_idx = buffer.find(end_flag) + # 查找 JPEG 图像的开始和结束标记 + start_idx = buffer.find(JPEG_START_FLAG) + end_idx = buffer.find(JPEG_END_FLAG) + + # 处理所有完整的 JPEG 图像 while start_idx != -1 and end_idx != -1 and end_idx > start_idx: - # Extract one JPEG image + # 提取一个 JPEG 图像 jpeg_image: bytearray = buffer[start_idx:end_idx + 2] self.jpeg_queue.put(jpeg_image) + # 从缓冲区中移除已处理的数据 buffer = buffer[end_idx + 2:] - # Search for the next JPEG image in the buffer - start_idx = buffer.find(start_flag) - end_idx = buffer.find(end_flag) + # 在缓冲区中查找下一个 JPEG 图像 + start_idx = buffer.find(JPEG_START_FLAG) + end_idx = buffer.find(JPEG_END_FLAG) def _video_writer(self): - """Write frames to video file.""" + """ + 视频写入工作线程 + + 将帧写入视频文件 + """ cv2_instance = None + img = None while not self.stop_event.is_set(): - if not self.jpeg_queue.empty(): - jpeg_image = self.jpeg_queue.get(timeout=0.1) + try: + # 从队列获取 JPEG 图像 + jpeg_image = self.jpeg_queue.get(timeout=QUEUE_TIMEOUT) img = cv2.imdecode(np.frombuffer(jpeg_image, np.uint8), cv2.IMREAD_COLOR) - if cv2_instance is None: - height, width = img.shape[:2] - fourcc = cv2.VideoWriter_fourcc(*'mp4v') - cv2_instance = cv2.VideoWriter(self.video_path, fourcc, 10, (width, height)) - - cv2_instance.write(img) - + except queue.Empty: + pass + + # 跳过无效图像 + if img is None or img.size == 0: + continue + + # 首次获取有效图像时初始化视频写入器 + if cv2_instance is None: + height, width = img.shape[:2] + fourcc = cv2.VideoWriter_fourcc(*VIDEO_CODEC) + cv2_instance = cv2.VideoWriter(self.video_path, fourcc, VIDEO_FPS, (width, height)) + + # 写入帧 + cv2_instance.write(img) + + # 释放资源 if cv2_instance: cv2_instance.release() def stop(self) -> str: + """ + 停止屏幕录制 + + Returns: + str: 视频保存路径 + """ try: + # 设置停止事件,通知工作线程退出 self.stop_event.set() + + # 等待所有工作线程结束 for t in self.threads: t.join() + # 发送停止录制命令 self._send_msg("stopCaptureScreen", []) - self._recv_msg(1024, decode=True, print=False) + self._recv_msg(decode=True, print=False) + # 释放资源 self.release() - # Invalidate the cached property + # 使缓存的属性失效 self.d._invalidate_cache('screenrecord') except Exception as e: - logger.error(f"An error occurred: {e}") + logger.error(f"停止屏幕录制时出错: {e}") return self.video_path diff --git a/hmdriver2/_swipe.py b/hmdriver2/_swipe.py index aab92a0..a080734 100644 --- a/hmdriver2/_swipe.py +++ b/hmdriver2/_swipe.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- - from typing import Union, Tuple - from .driver import Driver from .proto import SwipeDirection @@ -77,8 +75,8 @@ def _validate_and_convert_box(self, box: Tuple) -> Tuple[int, int, int, int]: raise ValueError("Box coordinates must satisfy x1 < x2 and y1 < y2.") from .driver import Point - p1: Point = self._d._to_abs_pos(x1, y1) - p2: Point = self._d._to_abs_pos(x2, y2) + p1: Point = self._d.to_abs_pos(x1, y1) + p2: Point = self._d.to_abs_pos(x2, y2) x1, y1, x2, y2 = p1.x, p1.y, p2.x, p2.y return x1, y1, x2, y2 diff --git a/hmdriver2/_uiobject.py b/hmdriver2/_uiobject.py index 31f58d0..41f912a 100644 --- a/hmdriver2/_uiobject.py +++ b/hmdriver2/_uiobject.py @@ -1,11 +1,7 @@ # -*- coding: utf-8 -*- - import enum import time -from typing import List, Union - -from . import logger -from .utils import delay +from typing import List, Union, Optional, Any from ._client import HmClient from .exception import ElementNotFoundError from .proto import ComponentData, ByData, HypiumResponse, Point, Bounds, ElementInfo @@ -25,168 +21,109 @@ class ByType(enum.Enum): selected = "selected" checked = "checked" checkable = "checkable" - isBefore = "isBefore" - isAfter = "isAfter" + isBefore = "isBefore" # 表示找前一个元素 + isAfter = "isAfter" # 表示找后一个元素 @classmethod def verify(cls, value): return any(value == item.value for item in cls) -class UiObject: - DEFAULT_TIMEOUT = 2 +class UiElement: + """表示单个UI元素,封装所有元素操作""" - def __init__(self, client: HmClient, **kwargs) -> None: + def __init__(self, client: HmClient, component: ComponentData) -> None: self._client = client - self._raw_kwargs = kwargs - - self._index = kwargs.pop("index", 0) - self._isBefore = kwargs.pop("isBefore", False) - self._isAfter = kwargs.pop("isAfter", False) - - self._kwargs = kwargs - self.__verify() - - self._component: Union[ComponentData, None] = None # cache - - def __str__(self) -> str: - return f"UiObject [{self._raw_kwargs}" - - def __verify(self): - for k, v in self._kwargs.items(): - if not ByType.verify(k): - raise ReferenceError(f"{k} is not allowed.") - - @property - def count(self) -> int: - eleements = self.__find_components() - return len(eleements) if eleements else 0 - - def __len__(self): - return self.count - - def exists(self, retries: int = 2, wait_time=1) -> bool: - obj = self.find_component(retries, wait_time) - return True if obj else False - - def __set_component(self, component: ComponentData): self._component = component - - def find_component(self, retries: int = 1, wait_time=1) -> ComponentData: - for attempt in range(retries): - components = self.__find_components() - if components and self._index < len(components): - self.__set_component(components[self._index]) - return self._component - - if attempt < retries: - time.sleep(wait_time) - logger.info(f"Retry found element {self}") - - return None - - # useless - def __find_component(self) -> Union[ComponentData, None]: - by: ByData = self.__get_by() - resp: HypiumResponse = self._client.invoke("Driver.findComponent", args=[by.value]) - if not resp.result: - return None - return ComponentData(resp.result) - - def __find_components(self) -> Union[List[ComponentData], None]: - by: ByData = self.__get_by() - resp: HypiumResponse = self._client.invoke("Driver.findComponents", args=[by.value]) - if not resp.result: - return None - components: List[ComponentData] = [] - for item in resp.result: - components.append(ComponentData(item)) - - return components - - def __get_by(self) -> ByData: - for k, v in self._kwargs.items(): - api = f"On.{k}" - this = "On#seed" - resp: HypiumResponse = self._client.invoke(api, this, args=[v]) - this = resp.result - - if self._isBefore: - resp: HypiumResponse = self._client.invoke("On.isBefore", this="On#seed", args=[resp.result]) - - if self._isAfter: - resp: HypiumResponse = self._client.invoke("On.isAfter", this="On#seed", args=[resp.result]) - - return ByData(resp.result) - - def __operate(self, api, args=[], retries: int = 2): - if not self._component: - if not self.find_component(retries): - raise ElementNotFoundError(f"Element({self}) not found after {retries} retries") - + self._last_check_time = 0 # 记录最后检查时间 + self._cached_state = None # 缓存元素状态 + self._state_cache = {} + self._cache_expiry = 0 + + def __operate(self, api, args=None): + if args is None: + args = [] resp: HypiumResponse = self._client.invoke(api, this=self._component.value, args=args) return resp.result + def _get_cached_property(self, prop_name: str) -> Any: + """带缓存的属性获取""" + current_time = time.perf_counter() + if current_time - self._cache_expiry < UiObject.CACHE_TTL: + return self._state_cache.get(prop_name) + + # 刷新整个状态缓存 + self._state_cache = { + key: self.__operate(f"Component.{key}") + for key in ( + "isSelected", "isChecked", "isEnabled", + "isFocused", "isClickable", "isLongClickable" + ) + } + self._cache_expiry = current_time + return self._state_cache.get(prop_name) + @property def id(self) -> str: - return self.__operate("Component.getId") + return self._get_cached_property("getId") @property def key(self) -> str: - return self.__operate("Component.getId") + return self._get_cached_property("getId") @property def type(self) -> str: - return self.__operate("Component.getType") + return self._get_cached_property("getType") @property def text(self) -> str: - return self.__operate("Component.getText") + return self._get_cached_property("getText") @property def description(self) -> str: - return self.__operate("Component.getDescription") + return self._get_cached_property("getDescription") @property def isSelected(self) -> bool: - return self.__operate("Component.isSelected") + + return self._get_cached_property("isSelected") @property def isChecked(self) -> bool: - return self.__operate("Component.isChecked") + return self._get_cached_property("isChecked") @property def isEnabled(self) -> bool: - return self.__operate("Component.isEnabled") + return self._get_cached_property("isEnabled") @property def isFocused(self) -> bool: - return self.__operate("Component.isFocused") + return self._get_cached_property("isFocused") @property def isCheckable(self) -> bool: - return self.__operate("Component.isCheckable") + return self._get_cached_property("isCheckable") @property def isClickable(self) -> bool: - return self.__operate("Component.isClickable") + return self._get_cached_property("isClickable") @property def isLongClickable(self) -> bool: - return self.__operate("Component.isLongClickable") + return self._get_cached_property("isLongClickable") @property def isScrollable(self) -> bool: - return self.__operate("Component.isScrollable") + return self._get_cached_property("isScrollable") @property def bounds(self) -> Bounds: - _raw = self.__operate("Component.getBounds") + _raw = self._get_cached_property("getBounds") return Bounds(**_raw) @property def boundsCenter(self) -> Point: - _raw = self.__operate("Component.getBoundsCenter") + _raw = self._get_cached_property("getBoundsCenter") return Point(**_raw) @property @@ -208,41 +145,272 @@ def info(self) -> ElementInfo: bounds=self.bounds, boundsCenter=self.boundsCenter) - @delay def click(self): + print(111) return self.__operate("Component.click") - @delay - def click_if_exists(self): - try: - return self.__operate("Component.click") - except ElementNotFoundError: - pass - - @delay def double_click(self): return self.__operate("Component.doubleClick") - @delay def long_click(self): return self.__operate("Component.longClick") - @delay - def drag_to(self, component: ComponentData): - return self.__operate("Component.dragTo", [component.value]) + def drag_to(self, element: 'UiElement'): + return self.__operate("Component.dragTo", [element._component.value]) - @delay def input_text(self, text: str): return self.__operate("Component.inputText", [text]) - @delay def clear_text(self): return self.__operate("Component.clearText") - @delay def pinch_in(self, scale: float = 0.5): return self.__operate("Component.pinchIn", [scale]) - @delay def pinch_out(self, scale: float = 2): return self.__operate("Component.pinchOut", [scale]) + + +class UiElementList: + """管理多个UI元素的集合类""" + + def __init__(self, elements: List[UiElement]) -> None: + self.elements = elements + + def __len__(self): + return len(self.elements) + + def __getitem__(self, index): + return self.elements[index] + + def __iter__(self): + return iter(self.elements) + + @property + def first(self) -> Optional[UiElement]: + return self.elements[0] if self.elements else None + + @property + def last(self) -> Optional[UiElement]: + return self.elements[-1] if self.elements else None + + def click_all(self): + for element in self.elements: + element.click() + + def get_by_index(self, index: int) -> Optional[UiElement]: + if 0 <= index < len(self.elements): + return self.elements[index] + return None + + +class UiObject: + DEFAULT_TIMEOUT = 2 + POLL_INTERVAL = 0.001 # 10ms轮询间隔 + CACHE_TTL = 0.05 # 缓存有效期200ms + + def __init__(self, client: HmClient, **kwargs) -> None: + self._client = client + self._raw_kwargs = kwargs + self._isBefore = kwargs.pop("isBefore", False) + self._isAfter = kwargs.pop("isAfter", False) + self._kwargs = kwargs + self.__verify() + self._cache = None # 查找结果缓存 + self._cache_time = 0 # 缓存时间戳 + self._by_cache = None # ByData缓存 + + def __str__(self) -> str: + return f"UiObject [{self._raw_kwargs}" + + def __verify(self): + for k, v in self._kwargs.items(): + if not ByType.verify(k): + raise ReferenceError(f"{k} is not allowed.") + + def _invalidate_cache(self): + """使缓存失效""" + self._cache = None + self._cache_time = 0 + self._by_cache = None + + @property + def count(self) -> int: + elements = self.find_components() + return len(elements) if elements else 0 + + def __len__(self): + return self.count + + def exists(self, timeout=0) -> bool: + """检查元素是否存在,支持即时检测""" + return len(self.find_components(timeout, use_cache=False)) > 0 + + def wait(self, exists=True, timeout=1) -> bool: + """等待元素出现或消失,延迟控制在10ms以内""" + start_time = time.perf_counter() + last_state = self.exists(0) + + # 初始状态检查 + if exists == last_state: + return exists + + # 优化后的轮询逻辑 + while time.perf_counter() - start_time < timeout: + current_state = self.exists(0) + # 状态变化立即返回 + if current_state != last_state: + if exists == current_state: + return True + last_state = current_state + + time.sleep(self.POLL_INTERVAL) + + # 超时后最终检查 + return exists == self.exists(0) + + def _poll_for_elements(self, timeout) -> list: + """封装轮询逻辑""" + start_time = time.perf_counter() + while True: + if components := self.__find_components(): + return [UiElement(self._client, comp) for comp in components] + + # 超时或无需轮询时退出 + if timeout <= 0 or (time.perf_counter() - start_time) >= timeout: + return [] + + time.sleep(self.POLL_INTERVAL) + + def find_components(self, timeout=DEFAULT_TIMEOUT, use_cache=True) -> UiElementList: + """改进的查找方法,支持禁用缓存""" + current_time = time.perf_counter() + + # 缓存有效且启用缓存时直接返回 + if use_cache and self._cache and current_time - self._cache_time < self.CACHE_TTL: + return UiElementList(self._cache) + + # 直接查找或轮询逻辑 + elements = self._poll_for_elements(timeout) + + # 更新缓存 + self._cache = elements + self._cache_time = time.perf_counter() + return UiElementList(elements) + + def __find_components(self) -> Union[List[ComponentData], None]: + """查找组件并缓存ByData""" + if self._by_cache: + by = self._by_cache + else: + by = self.__get_by() + self._by_cache = by + + resp: HypiumResponse = self._client.invoke("Driver.findComponents", args=[by.value]) + if not resp.result: + return None + return [ComponentData(item) for item in resp.result] + + def __get_by(self) -> ByData: + """链式构建 ByData 对象""" + seed = None + # 动态构建查询链 + for k, v in self._kwargs.items(): + api = f"On.{k}" + this = seed or "On#seed" + resp = self._client.invoke(api, this, [v]) + seed = resp.result + + # 处理位置关系 + position_actions = { + "isBefore": "On.isBefore", + "isAfter": "On.isAfter" + } + for attr, api in position_actions.items(): + if getattr(self, f"_{attr}", False): + resp = self._client.invoke(api, seed, [seed]) + seed = resp.result + + return ByData(seed) + + # def _perform_action(self, action: str, *args, **kwargs): + # """统一操作调度方法""" + # elements = self.find_components() + # if not elements or not elements.first: + # if kwargs.get("ignore_not_found"): + # return + # raise ElementNotFoundError(f"Element({self}) not found") + # + # # 动态调用 UiElement 方法 + # method = getattr(elements.first, action) + # return method(*args) + + def click(self): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.click() + + def click_first(self): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.click() + + def info(self): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + return elements.first.info + + def double_click(self): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.double_click() + + def long_click(self): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.long_click() + + def drag_to(self, element: 'UiElement'): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.drag_to(element) + + def input_text(self, text: str): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.input_text(text) + + def clear_text(self): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.clear_text() + + def pinch_in(self, scale: float = 0.5): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.pinch_in(scale) + + def pinch_out(self, scale: float = 2): + """点击找到的第一个元素""" + elements = self.find_components() + if not elements: + raise ElementNotFoundError(f"Element({self}) not found") + elements.first.pinch_out(scale) diff --git a/hmdriver2/_xpath.py b/hmdriver2/_xpath.py index f3ad21e..2666379 100644 --- a/hmdriver2/_xpath.py +++ b/hmdriver2/_xpath.py @@ -1,93 +1,352 @@ # -*- coding: utf-8 -*- - -from typing import Dict +import time +import threading +from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED +from typing import Dict, Any, List, Optional, Union from lxml import etree -from functools import cached_property - -from . import logger -from .proto import Bounds +from .proto import Bounds, ElementInfo, Point from .driver import Driver -from .utils import delay, parse_bounds -from .exception import XmlElementNotFoundError +from .utils import parse_bounds + +# XML相关常量 +XML_ROOT_TAG = "orgRoot" +XML_ATTRIBUTE_TYPE = "type" + +# 布尔属性列表 +BOOL_ATTRIBUTES = ["enabled", "focused", "selected", "checked", "checkable", "clickable", "longClickable", "scrollable"] class _XPath: def __init__(self, d: Driver): self._d = d - def __call__(self, xpath: str) -> '_XMLElement': + def __call__(self, xpath: Union[str, list]) -> '_XPathResult': + hierarchy_dict: dict | None = self._d.dump_hierarchy() + if not hierarchy_dict: + return _XPathResult([], self._d, xpath) + + xml = self._json2xml(hierarchy_dict) + if xml is None: + return _XPathResult([], self._d, xpath) + + if isinstance(xpath, list): + return self._concurrent_xpath_search(xml, xpath) + return self._single_xpath_search(xml, xpath) + + @staticmethod + def _sanitize_text(text: str) -> str: + """快速移除XML不兼容的控制字符""" + return ''.join(ch for ch in text if 31 < ord(ch) < 127) + + def _single_xpath_search(self, xml: etree.Element, xpath: str) -> '_XPathResult': + """处理单个XPath查询""" + try: + results = xml.xpath(xpath) + except etree.XPathError as e: + return _XPathResult([], self._d, xpath) + return _XPathResult([_XMLElement(node, self._d) for node in results], self._d, xpath) + + def _concurrent_xpath_search(self, xml: etree.Element, xpath_list: List[str]) -> '_XPathResult': + """ + 并发查询多个XPath表达式 + - 返回标准的_XPathResult对象 + - 包含第一个找到结果的表达式的元素 + - 在结果对象中保存命中的XPath表达式 + """ + found_event = threading.Event() + result_lock = threading.Lock() + result_elements = [] # 保存找到的元素 + hit_xpath = None # 保存命中的XPath表达式 + + def worker(expr: str): + nonlocal result_elements, hit_xpath + # 检查是否已有结果,避免不必要计算 + if found_event.is_set(): + return + try: + nodes = xml.xpath(expr) + if nodes: + with result_lock: + if not found_event.is_set(): # 双重检查 + # 保存结果和命中的表达式 + result_elements = [_XMLElement(node, self._d) for node in nodes] + hit_xpath = expr + found_event.set() # 通知其他线程终止 + except etree.XPathError: + pass # 忽略单个表达式的语法错误 - hierarchy: Dict = self._d.dump_hierarchy() - if not hierarchy: - raise XmlElementNotFoundError(f"xpath: {xpath} not found") + # 使用线程池提交任务 + with ThreadPoolExecutor(max_workers=min(len(xpath_list), 10)) as executor: + futures = [executor.submit(worker, expr) for expr in xpath_list] - xml = _XPath._json2xml(hierarchy) - result = xml.xpath(xpath) + # 等待首个结果或全部完成 + done, _ = wait(futures, return_when=FIRST_COMPLETED) - if len(result) > 0: - node = result[0] - raw_bounds: str = node.attrib.get("bounds") # [832,1282][1125,1412] - bounds: Bounds = parse_bounds(raw_bounds) - logger.debug(f"{xpath} Bounds: {bounds}") - return _XMLElement(bounds, self._d) + # 如果已有结果则立即取消未启动任务 + if found_event.is_set(): + for future in futures: + if not future.done(): + future.cancel() - return _XMLElement(None, self._d) + # 返回标准化的_XPathResult对象 + return _XPathResult( + elements=result_elements, + d=self._d, + xpath=hit_xpath or " | ".join(xpath_list), # 使用命中的表达式或组合表达式 + matched_xpath=hit_xpath # 新增字段保存命中的具体表达式 + ) @staticmethod - def _json2xml(hierarchy: Dict) -> etree.Element: - attributes = hierarchy.get("attributes", {}) - tag = attributes.get("type", "orgRoot") or "orgRoot" - xml = etree.Element(tag, attrib=attributes) + def _json2xml(hierarchy: Dict[str, Any]) -> etree.Element: + if not isinstance(hierarchy, dict): + return etree.Element(XML_ROOT_TAG) # 空根元素 - children = hierarchy.get("children", []) - for item in children: - xml.append(_XPath._json2xml(item)) - return xml + stack_xml = [(hierarchy, None)] + root = None + while stack_xml: + current_node, parent_node = stack_xml.pop() + if not isinstance(current_node, dict): + print('current_node', current_node) + continue -class _XMLElement: - def __init__(self, bounds: Bounds, d: Driver): - self.bounds = bounds + attributes = current_node.get("attributes", {}) + if not isinstance(attributes, dict): + attributes = {} + + cleaned_attributes = {} + for k, v in attributes.items(): + if k in BOOL_ATTRIBUTES: + cleaned_attributes[k] = v + else: + cleaned_attributes[k] = _XPath._sanitize_text(str(v)) if v is not None else "" + + tag_name = cleaned_attributes.get(XML_ATTRIBUTE_TYPE) or XML_ROOT_TAG + if not isinstance(tag_name, str): + tag_name = XML_ROOT_TAG + + node = etree.Element(tag_name, attrib=cleaned_attributes) + + if parent_node is None: + root = node + else: + parent_node.append(node) + + children = current_node.get("children", []) + if not isinstance(children, list): + children = [] + + valid_children = [] + for child in children: + if child and isinstance(child, dict): + valid_children.append(child) + + for child in reversed(valid_children): + stack_xml.append((child, node)) + + # 最终保障:确保永不返回None + return root if root is not None else etree.Element(XML_ROOT_TAG) + + +class _XPathResult: + __slots__ = ('elements', '_d', '_xpath', '_matched_xpath') + + def __init__(self, elements: List['_XMLElement'], d: Driver, xpath: str, matched_xpath: str = None): + self.elements = elements self._d = d + self._xpath = xpath # 原始查询的XPath(单个或组合) + self._matched_xpath = matched_xpath or xpath # 实际匹配的XPath - def _verify(self): - if not self.bounds: - raise XmlElementNotFoundError("xpath not found") + def find_all(self) -> List['_XMLElement']: + """返回所有匹配的元素""" + return self.elements - @cached_property - def center(self): - self._verify() - return self.bounds.get_center() + @property + def first(self) -> Optional['_XMLElement']: + """返回第一个匹配的元素(如果存在)""" + return self.elements[0] if self.elements else None + + @property + def count(self) -> int: + """返回匹配的元素数量""" + return len(self.elements) def exists(self) -> bool: - return self.bounds is not None + """检查是否存在匹配的元素""" + return len(self.elements) > 0 - @delay - def click(self): - x, y = self.center.x, self.center.y - self._d.click(x, y) + def wait(self, exists=True, timeout=1) -> bool: + start_time = time.perf_counter() + xpath_obj = _XPath(self._d) # 创建新的XPath查询对象 + + while time.perf_counter() - start_time < timeout: + # 每次迭代都重新查询 + current_result = xpath_obj(self._xpath) + current_state = current_result.exists() + + if exists: + if current_state: + # 更新元素列表为当前找到的元素 + self.elements = current_result.elements + return True + else: + if not current_state: + self.elements = [] + return True + # 最终检查 + final_result = xpath_obj(self._xpath) + final_state = final_result.exists() + if exists: + if final_state: + self.elements = final_result.elements + return True + return False + else: + if not final_state: + self.elements = [] + return True + return False + + def click_first(self): + """点击第一个匹配的元素""" + if self.elements: + self.elements[0].click() - @delay - def click_if_exists(self): + def click_all(self): + """点击所有匹配的元素""" + for element in self.elements: + element.click() # 直接点击,元素不存在时会抛出异常 - if not self.exists(): - logger.debug("click_exist: xpath not found") - return + def input_text(self, text: str): + """在第一个匹配的元素输入文本""" + if self.elements: + self.elements[0].input_text(text) - x, y = self.center.x, self.center.y - self._d.click(x, y) + @property + def matched_xpath(self) -> str: + """返回实际匹配的XPath表达式(多查询时有效)""" + return self._matched_xpath + + def __repr__(self) -> str: + return f"" + + def __getitem__(self, index: int) -> '_XMLElement': + """通过索引获取元素""" + return self.elements[index] + + def __iter__(self): + """支持迭代""" + return iter(self.elements) + + def __len__(self) -> int: + """返回元素数量""" + return self.count + + +class _XMLElement: + __slots__ = ('_d', 'attrib', '_bounds', '_center') + + def __init__(self, xpath_node: Any, d: Driver): + self._d = d + self.attrib = xpath_node.attrib + # 边界和中心点 + raw_bounds = self.attrib.get("bounds", "[0,0][0,0]") + self._bounds = parse_bounds(raw_bounds) + self._center = self._bounds.get_center() + + @property + def bounds(self) -> Bounds: + return self._bounds + + @property + def center(self) -> Point: + return self._center + + def click(self): + """点击元素中心点""" + self._d.click(self._center.x, self._center.y) - @delay def double_click(self): - x, y = self.center.x, self.center.y - self._d.double_click(x, y) + """双击元素""" + self._d.double_click(self._center.x, self._center.y) - @delay def long_click(self): - x, y = self.center.x, self.center.y - self._d.long_click(x, y) + """长按元素""" + self._d.long_click(self._center.x, self._center.y) - @delay def input_text(self, text): + """在元素中输入文本""" self.click() - self._d.input_text(text) \ No newline at end of file + self._d.input_text(text) + + @property + def id(self): + return self.attrib.get("id", "") + + @property + def key(self): + return self.attrib.get("key", "") + + @property + def type(self): + return self.attrib.get("type", "") + + @property + def text(self): + return self.attrib.get("text", "") + + @property + def description(self): + return self.attrib.get("description", "") + + @property + def isSelected(self): + return self.attrib.get("selected", "false") == "true" + + @property + def isChecked(self): + return self.attrib.get("checked", "false") == "true" + + @property + def isEnabled(self): + return self.attrib.get("enabled", "false") == "true" + + @property + def isFocused(self): + return self.attrib.get("focused", "false") == "true" + + @property + def isCheckable(self): + return self.attrib.get("checkable", "false") == "true" + + @property + def isClickable(self): + return self.attrib.get("clickable", "false") == "true" + + @property + def isLongClickable(self): + return self.attrib.get("longClickable", "false") == "true" + + @property + def isScrollable(self): + return self.attrib.get("scrollable", "false") == "true" + + @property + def info(self) -> ElementInfo: + return ElementInfo( + id=self.id, + key=self.key, + type=self.type, + text=self.text, + description=self.description, + isSelected=self.isSelected, + isChecked=self.isChecked, + isEnabled=self.isEnabled, + isFocused=self.isFocused, + isCheckable=self.isCheckable, + isClickable=self.isClickable, + isLongClickable=self.isLongClickable, + isScrollable=self.isScrollable, + bounds=self._bounds, + boundsCenter=self._center + ) diff --git a/hmdriver2/assets/uitest_agent_v1.1.7.so b/hmdriver2/assets/uitest_agent_v1.1.7.so new file mode 100644 index 0000000..70198f1 Binary files /dev/null and b/hmdriver2/assets/uitest_agent_v1.1.7.so differ diff --git a/hmdriver2/driver.py b/hmdriver2/driver.py index cde7fbb..b0e7f75 100644 --- a/hmdriver2/driver.py +++ b/hmdriver2/driver.py @@ -1,80 +1,133 @@ # -*- coding: utf-8 -*- - import json +import time import uuid -import re -from typing import Type, Any, Tuple, Dict, Union, List -from functools import cached_property # python3.8+ - -from . import logger +import atexit +from typing import Type, Tuple, Dict, Union, List, Optional +from functools import cached_property +from weakref import WeakValueDictionary from .utils import delay from ._client import HmClient from ._uiobject import UiObject +from .hdc import list_devices +from .exception import DeviceNotFoundError from .proto import HypiumResponse, KeyCode, Point, DisplayRotation, DeviceInfo, CommandResult class Driver: - _instance: Dict = {} - - def __init__(self, serial: str): - self.serial = serial - self._client = HmClient(self.serial) - self.hdc = self._client.hdc + _instance: Dict[str, "Driver"] = WeakValueDictionary() # 改用弱引用字典 + _cleanup_registered = False - self._init_hmclient() - - def __new__(cls: Type[Any], serial: str) -> Any: + def __new__(cls: Type["Driver"], serial: Optional[str] = None) -> "Driver": """ - Ensure that only one instance of Driver exists per device serial number. + Ensure that only one instance of Driver exists per serial. + If serial is None, use the first serial from list_devices(). """ + serial = cls._prepare_serial(serial) + if serial not in cls._instance: - cls._instance[serial] = super().__new__(cls) + instance = super().__new__(cls) + cls._instance[serial] = instance + instance._serial_for_init = serial # 临时存储serial用于初始化 + + # 注册全局清理(仅第一次) + if not cls._cleanup_registered: + atexit.register(cls._global_cleanup) + cls._cleanup_registered = True + return cls._instance[serial] - def __call__(self, **kwargs) -> UiObject: + def __init__(self, serial: Optional[str] = None): + """Initialize only once per instance.""" + if hasattr(self, "_initialized"): + return - return UiObject(self._client, **kwargs) + serial = getattr(self, "_serial_for_init", serial) + if serial is None: + raise ValueError("Serial number is required for initialization.") - def __del__(self): - if hasattr(self, '_client') and self._client: + self.serial = serial + print("开始启动", time.time()) + self._client = HmClient(self.serial) + self._client.start() + print("启动成功", time.time()) + self.hdc = self._client.hdc + self._initialized = True + del self._serial_for_init + + @classmethod + def _global_cleanup(cls): + """Safe cleanup during program exit.""" + for serial, instance in list(cls._instance.items()): + instance.close() + del cls._instance[serial] + + def close(self): + """Explicit resource release.""" + if hasattr(self, "_client") and self._client: self._client.release() + self._client = None - def _init_hmclient(self): - self._client.start() + def __call__(self, **kwargs) -> UiObject: + return UiObject(self._client, **kwargs) - def _invoke(self, api: str, args: List = []) -> HypiumResponse: + def _invoke(self, api: str, args=None) -> HypiumResponse: + if args is None: + args = [] return self._client.invoke(api, this="Driver#0", args=args) - @delay - def start_app(self, package_name: str, page_name: str = "MainAbility"): - self.hdc.start_app(package_name, page_name) + @classmethod + def _prepare_serial(cls, serial: Optional[str]) -> str: + """Validate device serial or auto-select first available.""" + devices = list_devices() + if not devices: + raise DeviceNotFoundError("No devices found.") + + if serial is None: + return devices[0] + if serial not in devices: + raise DeviceNotFoundError(f"Device [{serial}] not found") + return serial + + def app_start(self, package_name: str, page_name: Optional[str] = None): + """ + Start an application on the device. + If the `package_name` is empty, it will retrieve main ability using `get_app_main_ability`. - def force_start_app(self, package_name: str, page_name: str = "MainAbility"): + Args: + package_name (str): The package name of the application. + page_name (Optional[str]): Ability Name within the application to start. + """ + if not page_name: + page_name = self.get_app_main_ability(package_name).get('name', 'MainAbility') + self._client.hdc.app_start(package_name, page_name) + + def force_app_start(self, package_name: str, page_name: Optional[str] = None): self.go_home() - self.stop_app(package_name) - self.start_app(package_name, page_name) + self.app_stop(package_name) + self.app_start(package_name, page_name) - def stop_app(self, package_name: str): - self.hdc.stop_app(package_name) + def app_stop(self, package_name: str): + self._client.hdc.app_stop(package_name) def clear_app(self, package_name: str): """ Clear the application's cache and data. """ - self.hdc.shell(f"bm clean -n {package_name} -c") # clear cache - self.hdc.shell(f"bm clean -n {package_name} -d") # clear data + self._client.hdc.shell(f"bm clean -n {package_name} -c") # clear cache + self._client.hdc.shell(f"bm clean -n {package_name} -d") # clear data def install_app(self, apk_path: str): - self.hdc.install(apk_path) + self._client.hdc.install(apk_path) def uninstall_app(self, package_name: str): - self.hdc.uninstall(package_name) + self._client.hdc.uninstall(package_name) def list_apps(self) -> List: - return self.hdc.list_apps() + return self._client.hdc.list_apps() def has_app(self, package_name: str) -> bool: - return self.hdc.has_app(package_name) + return self._client.hdc.has_app(package_name) def current_app(self) -> Tuple[str, str]: """ @@ -85,7 +138,7 @@ def current_app(self) -> Tuple[str, str]: If no foreground application is found, returns (None, None). """ - return self.hdc.current_app() + return self._client.hdc.current_app() def get_app_info(self, package_name: str) -> Dict: """ @@ -99,7 +152,7 @@ def get_app_info(self, package_name: str) -> Dict: an empty dictionary is returned. """ app_info = {} - data: CommandResult = self.hdc.shell(f"bm dump -n {package_name}") + data: CommandResult = self._client.hdc.shell(f"bm dump -n {package_name}") output = data.output try: json_start = output.find("{") @@ -108,12 +161,76 @@ def get_app_info(self, package_name: str) -> Dict: app_info = json.loads(json_output) except Exception as e: - logger.error(f"An error occurred:{e}") + print(f"An error occurred:{e}") return app_info + def get_app_abilities(self, package_name: str) -> List[Dict]: + """ + Get the abilities of an application. + + Args: + package_name (str): The package name of the application. + + Returns: + List[Dict]: A list of dictionaries containing the abilities of the application. + """ + result = [] + app_info = self.get_app_info(package_name) + hap_module_infos = app_info.get("hapModuleInfos") + main_entry = app_info.get("mainEntry") + for hap_module_info in hap_module_infos: + # 尝试读取moduleInfo + try: + ability_infos = hap_module_info.get("abilityInfos") + module_main = hap_module_info["mainAbility"] + except Exception as e: + print(f"解析模块信息项失败, {repr(e)}") + continue + # 尝试读取abilityInfo + for ability_info in ability_infos: + try: + is_launcher_ability = False + skills = ability_info['skills'] + if len(skills) > 0 or "action.system.home" in skills[0]["actions"]: + is_launcher_ability = True + icon_ability_info = { + "name": ability_info["name"], + "moduleName": ability_info["moduleName"], + "moduleMainAbility": module_main, + "mainModule": main_entry, + "isLauncherAbility": is_launcher_ability + } + result.append(icon_ability_info) + except Exception as e: + print(f"解析ability_info项失败, {repr(e)}") + continue + return result + + def get_app_main_ability(self, package_name: str) -> Dict: + """ + Get the main ability of an application. + + Args: + package_name (str): The package name of the application to retrieve information for. + + Returns: + Dict: A dictionary containing the main ability of the application. + + """ + if not (abilities := self.get_app_abilities(package_name)): + return {} + for item in abilities: + score = 0 + if (name := item["name"]) and name == item["moduleMainAbility"]: + score += 1 + if (module_name := item["moduleName"]) and module_name == item["mainModule"]: + score += 1 + item["score"] = score + abilities.sort(key=lambda x: (not x["isLauncherAbility"], -x["score"])) + return abilities[0] + @cached_property def toast_watcher(self): - obj = self class _Watcher: @@ -122,7 +239,7 @@ def start(self) -> bool: resp: HypiumResponse = obj._invoke(api, args=["toastShow"]) return resp.result - def get_toast(self, timeout: int = 3) -> str: + def get_toast(self, timeout: int = 3) -> Union[str, None]: api = "Driver.getRecentUiEvent" resp: HypiumResponse = obj._invoke(api, args=[timeout]) if resp.result: @@ -131,23 +248,20 @@ def get_toast(self, timeout: int = 3) -> str: return _Watcher() - @delay def go_back(self): - self.hdc.send_key(KeyCode.BACK) + self._client.hdc.send_key(KeyCode.BACK) - @delay def go_home(self): - self.hdc.send_key(KeyCode.HOME) + self._client.hdc.send_key(KeyCode.HOME) - @delay def press_key(self, key_code: Union[KeyCode, int]): - self.hdc.send_key(key_code) + self._client.hdc.send_key(key_code) def screen_on(self): - self.hdc.wakeup() + self._client.hdc.wakeup() def screen_off(self): - self.hdc.wakeup() + self._client.hdc.wakeup() self.press_key(KeyCode.POWER) @delay @@ -163,6 +277,12 @@ def display_size(self) -> Tuple[int, int]: w, h = resp.result.get("x"), resp.result.get("y") return w, h + def window_size(self) -> Tuple[int, int]: + api = "Driver.getDisplaySize" + resp: HypiumResponse = self._invoke(api) + w, h = resp.result.get("x"), resp.result.get("y") + return w, h + @cached_property def display_rotation(self) -> DisplayRotation: api = "Driver.getDisplayRotation" @@ -174,7 +294,7 @@ def set_display_rotation(self, rotation: DisplayRotation): Sets the display rotation to the specified orientation. Args: - rotation (DisplayRotation): The desired display rotation. This should be an instance of the DisplayRotation enum. + rotation (DisplayRotation): display rotation. """ api = "Driver.setDisplayRotation" self._invoke(api, args=[rotation.value]) @@ -187,7 +307,7 @@ def device_info(self) -> DeviceInfo: Returns: DeviceInfo: An object containing various properties of the device. """ - hdc = self.hdc + hdc = self._client.hdc return DeviceInfo( productName=hdc.product_name(), model=hdc.model(), @@ -203,10 +323,10 @@ def device_info(self) -> DeviceInfo: def open_url(self, url: str, system_browser: bool = True): if system_browser: # Use the system browser - self.hdc.shell(f"aa start -A ohos.want.action.viewData -e entity.system.browsable -U {url}") + self._client.hdc.shell(f"aa start -A ohos.want.action.viewData -e entity.system.browsable -U {url}") else: # Default method - self.hdc.shell(f"aa start -U {url}") + self._client.hdc.shell(f"aa start -U {url}") def pull_file(self, rpath: str, lpath: str): """ @@ -216,7 +336,7 @@ def pull_file(self, rpath: str, lpath: str): rpath (str): The remote path of the file on the device. lpath (str): The local path where the file should be saved. """ - self.hdc.recv_file(rpath, lpath) + self._client.hdc.recv_file(rpath, lpath) def push_file(self, lpath: str, rpath: str): """ @@ -226,7 +346,7 @@ def push_file(self, lpath: str, rpath: str): lpath (str): The local path of the file. rpath (str): The remote path where the file should be saved on the device. """ - self.hdc.send_file(lpath, rpath) + self._client.hdc.send_file(lpath, rpath) def screenshot(self, path: str) -> str: """ @@ -246,9 +366,9 @@ def screenshot(self, path: str) -> str: return path def shell(self, cmd) -> CommandResult: - return self.hdc.shell(cmd) + return self._client.hdc.shell(cmd) - def _to_abs_pos(self, x: Union[int, float], y: Union[int, float]) -> Point: + def to_abs_pos(self, x: Union[int, float], y: Union[int, float]) -> Point: """ Convert percentages to absolute screen coordinates. @@ -270,27 +390,23 @@ def _to_abs_pos(self, x: Union[int, float], y: Union[int, float]) -> Point: y = int(h * y) return Point(int(x), int(y)) - @delay def click(self, x: Union[int, float], y: Union[int, float]): + # todo 使用hdc自带的uinput命令点击速度更快 + self._client.hdc.tap(x, y) + # point = self.to_abs_pos(x, y) + # api = "Driver.click" + # self._invoke(api, args=[point.x, point.y]) - # self.hdc.tap(point.x, point.y) - point = self._to_abs_pos(x, y) - api = "Driver.click" - self._invoke(api, args=[point.x, point.y]) - - @delay def double_click(self, x: Union[int, float], y: Union[int, float]): - point = self._to_abs_pos(x, y) + point = self.to_abs_pos(x, y) api = "Driver.doubleClick" self._invoke(api, args=[point.x, point.y]) - @delay def long_click(self, x: Union[int, float], y: Union[int, float]): - point = self._to_abs_pos(x, y) + point = self.to_abs_pos(x, y) api = "Driver.longClick" self._invoke(api, args=[point.x, point.y]) - @delay def swipe(self, x1, y1, x2, y2, speed=2000): """ Perform a swipe action on the device screen. @@ -300,14 +416,13 @@ def swipe(self, x1, y1, x2, y2, speed=2000): y1 (float): The start Y coordinate as a percentage or absolute value. x2 (float): The end X coordinate as a percentage or absolute value. y2 (float): The end Y coordinate as a percentage or absolute value. - speed (int, optional): The swipe speed in pixels per second. Default is 2000. Range: 200-40000. If not within the range, set to default value of 2000. + speed (int, optional): The swipe speed in pixels per second. Default is 2000. Range: 200-40000, + If not within the range, set to default value of 2000. """ - - point1 = self._to_abs_pos(x1, y1) - point2 = self._to_abs_pos(x2, y2) + point1 = self.to_abs_pos(x1, y1) + point2 = self.to_abs_pos(x2, y2) if speed < 200 or speed > 40000: - logger.warning("`speed` is not in the range[200-40000], Set to default value of 2000.") speed = 2000 api = "Driver.swipe" @@ -322,38 +437,41 @@ def swipe_ext(self): from ._swipe import SwipeExt return SwipeExt(self) - @delay def input_text(self, text: str): """ - Inputs text into the currently focused input field. + 在当前焦点输入框中输入文本 - Note: The input field must have focus before calling this method. + 注意:调用此方法前,输入框必须已获得焦点 Args: - text (str): input value + text: 要输入的文本 + + Returns: + HypiumResponse: API 调用响应 """ return self._invoke("Driver.inputText", args=[{"x": 1, "y": 1}, text]) - def dump_hierarchy(self) -> Dict: + def dump_hierarchy(self) -> Union[dict, None]: """ - Dump the UI hierarchy of the device screen. + 导出界面层次结构 Returns: - Dict: The dumped UI hierarchy as a dictionary. + str: 界面层次结构的 JSON 字符串 """ - # return self._client.invoke_captures("captureLayout").result - return self.hdc.dump_hierarchy() + # 环形缓冲区要设置大一些,才能获取到消息头和尾部。完美解决 + result = self._client.invoke_captures("captureLayout") + if result: + xml_result = result.result + if isinstance(xml_result, dict): + return xml_result + return json.loads(xml_result) + return None @cached_property def gesture(self): from ._gesture import _Gesture return _Gesture(self) - @cached_property - def screenrecord(self): - from ._screenrecord import RecordClient - return RecordClient(self.serial, self) - def _invalidate_cache(self, attribute_name): """ Invalidate the cached property. diff --git a/hmdriver2/hdc.py b/hmdriver2/hdc.py index d9d5405..da33729 100644 --- a/hmdriver2/hdc.py +++ b/hmdriver2/hdc.py @@ -1,14 +1,12 @@ # -*- coding: utf-8 -*- - -import tempfile import json +import tempfile +import os import uuid import shlex import re import subprocess -from typing import Union, List, Dict, Tuple - -from . import logger +from typing import Union, List, Tuple, Dict from .utils import FreePort from .proto import CommandResult, KeyCode from .exception import HdcError, DeviceNotFoundError @@ -17,19 +15,20 @@ def _execute_command(cmdargs: Union[str, List[str]]) -> CommandResult: if isinstance(cmdargs, (list, tuple)): cmdline: str = ' '.join(list(map(shlex.quote, cmdargs))) - elif isinstance(cmdargs, str): + else: cmdline = cmdargs - logger.debug(cmdline) try: - process = subprocess.Popen(cmdline, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=True) + process = subprocess.Popen(cmdline, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True) output, error = process.communicate() output = output.decode('utf-8') error = error.decode('utf-8') exit_code = process.returncode - if output.lower().__contains__('error:'): + if 'error:' in output.lower() or '[fail]' in output.lower(): return CommandResult("", output, -1) return CommandResult(output, error, exit_code) @@ -37,9 +36,21 @@ def _execute_command(cmdargs: Union[str, List[str]]) -> CommandResult: return CommandResult("", str(e), -1) +def _build_hdc_prefix() -> str: + """ + Construct the hdc command prefix based on environment variables. + """ + host = os.getenv("HDC_SERVER_HOST") + port = os.getenv("HDC_SERVER_PORT") + if host and port: + return f"hdc -s {host}:{port}" + return "hdc" + + def list_devices() -> List[str]: devices = [] - result = _execute_command('hdc list targets') + hdc_prefix = _build_hdc_prefix() + result = _execute_command(f"{hdc_prefix} list targets") if result.exit_code == 0 and result.output: lines = result.output.strip().split('\n') for line in lines: @@ -56,6 +67,8 @@ def list_devices() -> List[str]: class HdcWrapper: def __init__(self, serial: str) -> None: self.serial = serial + self.hdc_prefix = _build_hdc_prefix() + if not self.is_online(): raise DeviceNotFoundError(f"Device [{self.serial}] not found") @@ -63,55 +76,63 @@ def is_online(self): _serials = list_devices() return True if self.serial in _serials else False - def forward_port(self, rport: int) -> int: - lport: int = FreePort().get() - result = _execute_command(f"hdc -t {self.serial} fport tcp:{lport} tcp:{rport}") + def forward_port(self, r_port: int) -> int: + l_port: int = FreePort().get() + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} fport tcp:{l_port} tcp:{r_port}") if result.exit_code != 0: raise HdcError("HDC forward port error", result.error) - return lport + return l_port - def rm_forward(self, lport: int, rport: int) -> int: - result = _execute_command(f"hdc -t {self.serial} fport rm tcp:{lport} tcp:{rport}") + def rm_forward(self, l_port: int, r_port: int) -> int: + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} fport rm tcp:{l_port} tcp:{r_port}") if result.exit_code != 0: raise HdcError("HDC rm forward error", result.error) - return lport + return l_port def list_fport(self) -> List: """ eg.['tcp:10001 tcp:8012', 'tcp:10255 tcp:8012'] """ - result = _execute_command(f"hdc -t {self.serial} fport ls") + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} fport ls") if result.exit_code != 0: raise HdcError("HDC forward list error", result.error) pattern = re.compile(r"tcp:\d+ tcp:\d+") return pattern.findall(result.output) def send_file(self, lpath: str, rpath: str): - result = _execute_command(f"hdc -t {self.serial} file send {lpath} {rpath}") + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} file send {lpath} {rpath}") if result.exit_code != 0: raise HdcError("HDC send file error", result.error) return result def recv_file(self, rpath: str, lpath: str): - result = _execute_command(f"hdc -t {self.serial} file recv {rpath} {lpath}") + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} file recv {rpath} {lpath}") if result.exit_code != 0: raise HdcError("HDC receive file error", result.error) return result def shell(self, cmd: str, error_raise=True) -> CommandResult: - result = _execute_command(f"hdc -t {self.serial} shell {cmd}") + # ensure the command is wrapped in double quotes + if cmd[0] != '\"': + cmd = "\"" + cmd + if cmd[-1] != '\"': + cmd += '\"' + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} shell {cmd}") if result.exit_code != 0 and error_raise: raise HdcError("HDC shell error", f"{cmd}\n{result.output}\n{result.error}") return result def uninstall(self, bundlename: str): - result = _execute_command(f"hdc -t {self.serial} uninstall {bundlename}") + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} uninstall {bundlename}") if result.exit_code != 0: raise HdcError("HDC uninstall error", result.output) return result def install(self, apkpath: str): - result = _execute_command(f"hdc -t {self.serial} install '{apkpath}'") + # Ensure the path is properly quoted for Windows + quoted_path = f'"{apkpath}"' + + result = _execute_command(f"{self.hdc_prefix} -t {self.serial} install {quoted_path}") if result.exit_code != 0: raise HdcError("HDC install error", result.error) return result @@ -125,10 +146,10 @@ def has_app(self, package_name: str) -> bool: data = self.shell("bm dump -a").output return True if package_name in data else False - def start_app(self, package_name: str, ability_name: str): + def app_start(self, package_name: str, ability_name: str): return self.shell(f"aa start -a {ability_name} -b {package_name}") - def stop_app(self, package_name: str): + def app_stop(self, package_name: str): return self.shell(f"aa force-stop {package_name}") def current_app(self) -> Tuple[str, str]: @@ -140,23 +161,23 @@ def current_app(self) -> Tuple[str, str]: If no foreground application is found, returns (None, None). """ - def __extract_info(output: str): - results = [] + def __extract_info(_output: str): + _results = [] - mission_blocks = re.findall(r'Mission ID #[\s\S]*?isKeepAlive: false\s*}', output) + mission_blocks = re.findall(r'Mission ID #[\s\S]*?isKeepAlive: false\s*}', _output) if not mission_blocks: - return results + return _results for block in mission_blocks: if 'state #FOREGROUND' in block: - bundle_name_match = re.search(r'bundle name \[(.*?)\]', block) - main_name_match = re.search(r'main name \[(.*?)\]', block) + bundle_name_match = re.search(r'bundle name \[(.*?)\\]', block) + main_name_match = re.search(r'main name \[(.*?)\\]', block) if bundle_name_match and main_name_match: package_name = bundle_name_match.group(1) page_name = main_name_match.group(1) - results.append((package_name, page_name)) + _results.append((package_name, page_name)) - return results + return _results data: CommandResult = self.shell("aa dump -l") output = data.output @@ -215,8 +236,8 @@ def display_size(self) -> Tuple[int, int]: if match: w = int(match.group(1)) h = int(match.group(2)) - return (w, h) - return (0, 0) + return w, h + return 0, 0 def send_key(self, key_code: Union[KeyCode, int]) -> None: if isinstance(key_code, KeyCode): @@ -229,7 +250,8 @@ def send_key(self, key_code: Union[KeyCode, int]) -> None: self.shell(f"uitest uiInput keyEvent {key_code}") def tap(self, x: int, y: int) -> None: - self.shell(f"uitest uiInput click {x} {y}") + # 点击用这个方法,速度更快 参考文档 https://gitee.com/openharmony/docs/blob/master/zh-cn/application-dev/dfx/uinput.md + self.shell(f"uinput -T -c {x} {y}") def swipe(self, x1, y1, x2, y2, speed=1000): self.shell(f"uitest uiInput swipe {x1} {y1} {x2} {y2} {speed}") @@ -254,10 +276,9 @@ def dump_hierarchy(self) -> Dict: self.recv_file(_tmp_path, path) try: - with open(path, 'r') as file: + with open(path, 'r', encoding='utf8') as file: data = json.load(file) except Exception as e: - logger.error(f"Error loading JSON file: {e}") data = {} - return data + return data \ No newline at end of file diff --git a/hmdriver2/proto.py b/hmdriver2/proto.py index e987cc9..59f08da 100644 --- a/hmdriver2/proto.py +++ b/hmdriver2/proto.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- - import json from enum import Enum -from typing import Union, List +from typing import Union, Dict, List from dataclasses import dataclass, asdict @@ -35,12 +34,12 @@ def from_value(cls, value): class AppState: - INIT = 0 # 初始化状态,应用正在初始化 - READY = 1 # 就绪状态,应用已初始化完毕 + INIT = 0 # 初始化状态,应用正在初始化 + READY = 1 # 就绪状态,应用已初始化完毕 FOREGROUND = 2 # 前台状态,应用位于前台 - FOCUS = 3 # 获焦状态。(预留状态,当前暂不支持) + FOCUS = 3 # 获焦状态。(预留状态,当前暂不支持) BACKGROUND = 4 # 后台状态,应用位于后台 - EXIT = 5 # 退出状态,应用已退出 + EXIT = 5 # 退出状态,应用已退出 @dataclass @@ -64,8 +63,9 @@ class HypiumResponse: {"result":null,"exception":"Can not connect to AAMS, RET_ERR_CONNECTION_EXIST"} {"exception":{"code":401,"message":"(PreProcessing: APiCallInfoChecker)Illegal argument count"}} """ - result: Union[List, bool, str, None] = None - exception: Union[List, bool, str, None] = None + result: Union[List, Dict, bool, str, None] = None + exception: Union[List, Dict, bool, str, None] = None + pts:Union[List, Dict, bool, str, None] = None @dataclass @@ -75,7 +75,7 @@ class ByData: @dataclass class DriverData: - value: str # "Driver#0" + value: str # "Driver#0" @dataclass @@ -85,8 +85,13 @@ class ComponentData: @dataclass class Point: - x: int - y: int + def __init__(self, x, y): + self.x = x + self.y = y + + def __iter__(self): + yield self.x + yield self.y def to_tuple(self): return self.x, self.y @@ -100,14 +105,21 @@ def to_dict(self): @dataclass class Bounds: - left: int - top: int - right: int - bottom: int + + def __init__(self, left, top, right, bottom): + self.left = left + self.top = top + self.right = right + self.bottom = bottom + + def __iter__(self): + yield self.left + yield self.top + yield self.right + yield self.bottom def get_center(self) -> Point: - return Point(int((self.left + self.right) / 2), - int((self.top + self.bottom) / 2)) + return Point(int((self.left + self.right) / 2), int((self.top + self.bottom) / 2)) @dataclass @@ -126,7 +138,7 @@ class ElementInfo: isLongClickable: bool isScrollable: bool bounds: Bounds - boundsCenter: Point + boundsCenter: Union[Point, None] = None def __str__(self) -> str: return json.dumps(asdict(self), indent=4) @@ -471,4 +483,4 @@ class KeyCode(Enum): BTN_6 = 3106 # 按键6 BTN_7 = 3107 # 按键7 BTN_8 = 3108 # 按键8 - BTN_9 = 3109 # 按键9 \ No newline at end of file + BTN_9 = 3109 # 按键9 diff --git a/hmdriver2/utils.py b/hmdriver2/utils.py index cf4f75a..37bf78b 100644 --- a/hmdriver2/utils.py +++ b/hmdriver2/utils.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- - - import time import socket import re diff --git a/test_example.py b/test_example.py new file mode 100644 index 0000000..595604d --- /dev/null +++ b/test_example.py @@ -0,0 +1,70 @@ +import time +from hmdriver2.driver import Driver + + +class TestExample: + demo_test = None + + @classmethod + def setup_class(cls): + """类级别的前置方法,整个测试类只执行一次""" + cls.demo_test = Driver("2PM0224423003375") + cls.demo_test.app_start(package_name="com.qihoo.smartoh", page_name="EntryAbility") + + @classmethod + def teardown_class(cls): + """类级别的后置方法,整个测试类只执行一次""" + cls.demo_test = None + print("Driver closed") + + def setup_method(self): + """完全对齐 uiautomator2的元素查询,属性获取""" + # text 查询 + ele = self.demo_test(text='settings') + print(ele.info) + ele.click() + self.demo_test.go_back() + + # 对齐uiautomator2 仅方法名不一样 + # 1. 判断元素消失或出现 + # 在指定时间内,如10秒内等待元素消失 或 出现。当符合条件是返回True; 否则10秒超时后,返回False + # 1.1 text 等其他选择器 + ele_exist = self.demo_test(text='settings').wait(exists=True, timeout=10) # exists=False 当元消失返回True + print("元素状态", ele_exist) + + # 1.2 xpath + xpath_exist = self.demo_test.xpath('//*[contains(@text,"card_5081e5dd2d6a")]').wait(exists=True, timeout=10) + print("元素状态", xpath_exist) + + # 2. 元素查询 + # 2.1 text 等其他选择器 + # 10秒内一直查询,直到10秒超时: 默认返回多元素列表,对齐 uiautomator2 + text_ele = self.demo_test(text='settings').find_components(timeout=10) + if text_ele: + print(text_ele) + # 链式调用 + aa = text_ele[0].bounds + text_ele[1].click() + + # 2.2 xpath 选择器. 没有超时逻辑,可自行封装,对齐 uiautomator2 + # 2.3 高级功能 支持xpath 表达式列表。并发查询如 ['//*[contains(@text,"card_123")]', '//*[contains(@text,"card_345")]'] + xpath_elements = self.demo_test.xpath('//*[contains(@text,"card_5081e5dd2d6a")]').find_all() + if xpath_elements: + print(xpath_elements) + # 链式调用 + bb = xpath_elements[0].bounds + xpath_elements[1].click() + + def teardown_method(self): + """方法级别的后置方法,每个测试方法执行后都会执行""" + # 返回首页 + self.demo_test.go_back() + time.sleep(3) + + def test_example_case(self): + # 判断开流 + rr = self.demo_test.xpath('//*[@text="camera_rtc_duplex"]') + time.sleep(2) + assert rr.enabled == True + +# pytest test_example.py -s -v --count=20