diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/azure/__init__.py index 733ba05d1a9e..12851185e60c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/azure/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/azure/__init__.py @@ -4,6 +4,7 @@ SearchQuery, SearchResult, SearchResults, + VectorizableTextQuery, ) from ._config import AzureAISearchConfig @@ -14,4 +15,5 @@ "SearchResult", "SearchResults", "AzureAISearchConfig", + "VectorizableTextQuery", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py index cf0570d82727..447a716b695e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py @@ -1,112 +1,87 @@ +from __future__ import annotations + +import asyncio import logging import time from abc import ABC, abstractmethod from contextvars import ContextVar -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, overload - -from autogen_core import CancellationToken, ComponentModel +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Protocol, + Union, +) + +from autogen_core import CancellationToken, Component from autogen_core.tools import BaseTool, ToolSchema -from azure.core.credentials import AzureKeyCredential, TokenCredential +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import HttpResponseError, ResourceNotFoundError from azure.search.documents.aio import SearchClient from pydantic import BaseModel, Field -if TYPE_CHECKING: - from azure.search.documents.models import VectorizableTextQuery - -_has_retry_policy = False -try: - from azure.core.pipeline.policies import 
RetryPolicy # type: ignore[assignment] - - _has_retry_policy = True -except ImportError: - - class RetryPolicy: # type: ignore - def __init__(self, retry_mode: str = "fixed", retry_total: int = 3, **kwargs: Any) -> None: - pass - - _has_retry_policy = False - -HAS_RETRY_POLICY = _has_retry_policy - -has_azure_search = False - -if not TYPE_CHECKING: - try: - from azure.search.documents.models import VectorizableTextQuery - - has_azure_search = True - except ImportError: +from ._config import ( + DEFAULT_API_VERSION, + AzureAISearchConfig, +) - class VectorizableTextQuery: - """Fallback implementation when Azure SDK is not installed.""" +SearchDocument = Dict[str, Any] +MetadataDict = Dict[str, Any] +ContentDict = Dict[str, Any] - def __init__(self, text: str, k: int, fields: Union[str, List[str]]) -> None: - self.text = text - self.k = k - self.fields = fields if isinstance(fields, str) else ",".join(fields) - - -class _FallbackAzureAISearchConfig: - """Fallback configuration class for Azure AI Search when the main config module is not available. +if TYPE_CHECKING: + from azure.search.documents.aio import AsyncSearchItemPaged - This class provides a simple dictionary-based configuration object that mimics the behavior - of the AzureAISearchConfig from the _config module. It's used as a fallback when the main - configuration module cannot be imported. 
+ SearchResultsIterable = AsyncSearchItemPaged[SearchDocument] +else: + SearchResultsIterable = Any - Args: - **kwargs (Any): Keyword arguments containing configuration values - """ +logger = logging.getLogger(__name__) - def __init__(self, **kwargs: Any): - self.name = kwargs.get("name", "") - self.description = kwargs.get("description", "") - self.endpoint = kwargs.get("endpoint", "") - self.index_name = kwargs.get("index_name", "") - self.credential = kwargs.get("credential", None) - self.api_version = kwargs.get("api_version", "") - self.query_type = kwargs.get("query_type", "simple") - self.search_fields = kwargs.get("search_fields", None) - self.select_fields = kwargs.get("select_fields", None) - self.vector_fields = kwargs.get("vector_fields", None) - self.filter = kwargs.get("filter", None) - self.top = kwargs.get("top", None) - self.retry_enabled = kwargs.get("retry_enabled", False) - self.retry_mode = kwargs.get("retry_mode", "fixed") - self.retry_max_attempts = kwargs.get("retry_max_attempts", 3) - self.enable_caching = kwargs.get("enable_caching", False) - self.cache_ttl_seconds = kwargs.get("cache_ttl_seconds", 300) - - -AzureAISearchConfig: Any +if TYPE_CHECKING: + from azure.search.documents.models import ( + VectorizableTextQuery, + VectorizedQuery, + VectorQuery, + ) try: - from ._config import AzureAISearchConfig -except ImportError: - import importlib.util - import os - - current_dir = os.path.dirname(os.path.abspath(__file__)) - config_path = os.path.join(current_dir, "_config.py") - config_module = None + from azure.search.documents.models import VectorizableTextQuery, VectorizedQuery, VectorQuery - spec_config = importlib.util.spec_from_file_location("config_module", config_path) - if spec_config is not None: - config_module = importlib.util.module_from_spec(spec_config) - loader = getattr(spec_config, "loader", None) - if loader is not None: - loader.exec_module(config_module) + has_azure_search = True +except ImportError: + has_azure_search 
= False + logger.error( + "The 'azure-search-documents' package is required for this tool but was not found. " + "Please install it with: uv add install azure-search-documents" + ) - if config_module is not None and hasattr(config_module, "AzureAISearchConfig"): - AzureAISearchConfig = config_module.AzureAISearchConfig - else: - AzureAISearchConfig = _FallbackAzureAISearchConfig +if TYPE_CHECKING: + from typing import Protocol + + class SearchClientProtocol(Protocol): + async def search(self, **kwargs: Any) -> SearchResultsIterable: ... + async def close(self) -> None: ... +else: + SearchClientProtocol = Any + +__all__ = [ + "AzureAISearchTool", + "BaseAzureAISearchTool", + "SearchQuery", + "SearchResults", + "SearchResult", + "VectorizableTextQuery", + "VectorizedQuery", + "VectorQuery", +] logger = logging.getLogger(__name__) -T = TypeVar("T", bound="BaseAzureAISearchTool") -ExpectedType = TypeVar("ExpectedType") - class SearchQuery(BaseModel): """Search query parameters. @@ -127,13 +102,13 @@ class SearchResult(BaseModel): Args: score (float): The search score. - content (Dict[str, Any]): The document content. - metadata (Dict[str, Any]): Additional metadata about the document. + content (ContentDict): The document content. + metadata (MetadataDict): Additional metadata about the document. 
""" score: float = Field(description="The search score") - content: Dict[str, Any] = Field(description="The document content") - metadata: Dict[str, Any] = Field(description="Additional metadata about the document") + content: ContentDict = Field(description="The document content") + metadata: MetadataDict = Field(description="Additional metadata about the document") class SearchResults(BaseModel): @@ -146,7 +121,100 @@ class SearchResults(BaseModel): results: List[SearchResult] = Field(description="List of search results") -class BaseAzureAISearchTool(BaseTool[SearchQuery, SearchResults], ABC): +class EmbeddingProvider(Protocol): + """Protocol defining the interface for embedding generation.""" + + async def _get_embedding(self, query: str) -> List[float]: + """Generate embedding vector for the query text.""" + ... + + +class EmbeddingProviderMixin: + """Mixin class providing embedding generation functionality.""" + + search_config: AzureAISearchConfig + + async def _get_embedding(self, query: str) -> List[float]: + """Generate embedding vector for the query text.""" + if not hasattr(self, "search_config"): + raise ValueError("Host class must have a search_config attribute") + + search_config = self.search_config + embedding_provider = getattr(search_config, "embedding_provider", None) + embedding_model = getattr(search_config, "embedding_model", None) + + if not embedding_provider or not embedding_model: + raise ValueError( + "Client-side embedding is not configured. `embedding_provider` and `embedding_model` must be set." + ) from None + + if embedding_provider.lower() == "azure_openai": + try: + from azure.identity import DefaultAzureCredential + from openai import AsyncAzureOpenAI + except ImportError: + raise ImportError( + "Azure OpenAI SDK is required for client-side embedding generation. 
" + "Please install it with: uv add openai azure-identity" + ) from None + + api_key = getattr(search_config, "openai_api_key", None) + api_version = getattr(search_config, "openai_api_version", "2023-11-01") + endpoint = getattr(search_config, "openai_endpoint", None) + + if not endpoint: + raise ValueError( + "Azure OpenAI endpoint (`openai_endpoint`) must be provided for client-side Azure OpenAI embeddings." + ) from None + + if api_key: + azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint) + else: + + def get_token() -> str: + credential = DefaultAzureCredential() + token = credential.get_token("https://cognitiveservices.azure.com/.default") + if not token or not token.token: + raise ValueError("Failed to acquire token using DefaultAzureCredential for Azure OpenAI.") + return token.token + + azure_client = AsyncAzureOpenAI( + azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint + ) + + try: + response = await azure_client.embeddings.create(model=embedding_model, input=query) + return response.data[0].embedding + except Exception as e: + raise ValueError(f"Failed to generate embeddings with Azure OpenAI: {str(e)}") from e + + elif embedding_provider.lower() == "openai": + try: + from openai import AsyncOpenAI + except ImportError: + raise ImportError( + "OpenAI SDK is required for client-side embedding generation. " + "Please install it with: uv add openai" + ) from None + + api_key = getattr(search_config, "openai_api_key", None) + openai_client = AsyncOpenAI(api_key=api_key) + + try: + response = await openai_client.embeddings.create(model=embedding_model, input=query) + return response.data[0].embedding + except Exception as e: + raise ValueError(f"Failed to generate embeddings with OpenAI: {str(e)}") from e + else: + raise ValueError( + f"Unsupported client-side embedding provider: {embedding_provider}. " + "Currently supported providers are 'azure_openai' and 'openai'." 
+ ) + + +class BaseAzureAISearchTool( + BaseTool[SearchQuery, SearchResults], Component[AzureAISearchConfig], EmbeddingProvider, ABC +): """Abstract base class for Azure AI Search tools. This class defines the common interface and functionality for all Azure AI Search tools. @@ -161,15 +229,18 @@ class BaseAzureAISearchTool(BaseTool[SearchQuery, SearchResults], ABC): Use concrete implementations or the factory methods in AzureAISearchTool. """ + component_config_schema = AzureAISearchConfig + component_provider_override = "autogen_ext.tools.azure.BaseAzureAISearchTool" + def __init__( self, name: str, endpoint: str, index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], description: Optional[str] = None, - api_version: str = "2023-11-01", - query_type: Literal["keyword", "fulltext", "vector", "semantic"] = "keyword", + api_version: str = DEFAULT_API_VERSION, + query_type: Literal["simple", "full", "semantic", "vector"] = "simple", search_fields: Optional[List[str]] = None, select_fields: Optional[List[str]] = None, vector_fields: Optional[List[str]] = None, @@ -178,6 +249,11 @@ def __init__( semantic_config_name: Optional[str] = None, enable_caching: bool = False, cache_ttl_seconds: int = 300, + embedding_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + openai_api_key: Optional[str] = None, + openai_api_version: Optional[str] = None, + openai_endpoint: Optional[str] = None, ): """Initialize the Azure AI Search tool. 
@@ -185,10 +261,10 @@ def __init__( name (str): The name of this tool instance endpoint (str): The full URL of your Azure AI Search service index_name (str): Name of the search index to query - credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Azure credential for authentication (API key or token) + credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Azure credential for authentication description (Optional[str]): Optional description explaining the tool's purpose - api_version (str): Azure AI Search API version to use - query_type (Literal["keyword", "fulltext", "vector", "semantic"]): Type of search to perform + api_version (Optional[str]): Azure AI Search API version to use + query_type (Literal["simple", "full", "semantic", "vector"]): Type of search to perform search_fields (Optional[List[str]]): Fields to search within documents select_fields (Optional[List[str]]): Fields to return in search results vector_fields (Optional[List[str]]): Fields to use for vector search @@ -197,6 +273,11 @@ def __init__( semantic_config_name (Optional[str]): Semantic configuration name for enhanced results enable_caching (bool): Whether to cache search results cache_ttl_seconds (int): How long to cache results in seconds + embedding_provider (Optional[str]): Name of embedding provider for client-side embeddings + embedding_model (Optional[str]): Model name for client-side embeddings + openai_api_key (Optional[str]): API key for OpenAI/Azure OpenAI embeddings + openai_api_version (Optional[str]): API version for Azure OpenAI embeddings + openai_endpoint (Optional[str]): Endpoint URL for Azure OpenAI embeddings """ if not has_azure_search: raise ImportError( @@ -217,12 +298,14 @@ def __init__( description=description, ) - self.search_config = AzureAISearchConfig( + processed_credential = self._process_credential(credential) + + self.search_config: AzureAISearchConfig = AzureAISearchConfig( name=name, description=description, 
endpoint=endpoint, index_name=index_name, - credential=self._process_credential(credential), + credential=processed_credential, api_version=api_version, query_type=query_type, search_fields=search_fields, @@ -233,35 +316,77 @@ def __init__( semantic_config_name=semantic_config_name, enable_caching=enable_caching, cache_ttl_seconds=cache_ttl_seconds, + embedding_provider=embedding_provider, + embedding_model=embedding_model, + openai_api_key=openai_api_key, + openai_api_version=openai_api_version, + openai_endpoint=openai_endpoint, ) self._endpoint = endpoint self._index_name = index_name - self._credential = credential + self._credential = processed_credential self._api_version = api_version + self._client: Optional[SearchClient] = None self._cache: Dict[str, Dict[str, Any]] = {} + if self.search_config.api_version == "2023-11-01" and self.search_config.vector_fields: + warning_message = ( + f"When explicitly setting api_version='{self.search_config.api_version}' for vector search: " + f"If client-side embedding is NOT configured (e.g., `embedding_model` is not set), " + f"this tool defaults to service-side vectorization (VectorizableTextQuery), which may fail or have limitations with this API version. " + f"If client-side embedding IS configured, the tool will use VectorizedQuery, which is generally compatible. " + f"For robust vector search, consider omitting api_version (recommended to use SDK default) or use a newer API version." 
+ ) + logger.warning(warning_message) + async def close(self) -> None: - """Explicitly close the Azure SearchClient if needed (for cleanup in long-running apps/tests).""" + """Explicitly close the Azure SearchClient if needed (for cleanup).""" if self._client is not None: - await self._client.close() - self._client = None + try: + await self._client.close() + except Exception: + pass + finally: + self._client = None def _process_credential( - self, credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]] - ) -> Union[AzureKeyCredential, TokenCredential]: - """Process credential to ensure it's the correct type.""" + self, credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]] + ) -> Union[AzureKeyCredential, AsyncTokenCredential]: + """Process credential to ensure it's the correct type for async SearchClient. + + Converts dictionary credentials with 'api_key' to AzureKeyCredential objects. + + Args: + credential: The credential in either object or dictionary form + + Returns: + A properly formatted credential object + + Raises: + ValueError: If the credential dictionary doesn't contain an 'api_key' + TypeError: If the credential is not of a supported type + """ if isinstance(credential, dict): if "api_key" in credential: return AzureKeyCredential(credential["api_key"]) - raise ValueError( - "If credential is a dict, it must contain an 'api_key' key with your API key as the value" - ) from None - return credential + raise ValueError("If credential is a dict, it must contain an 'api_key' key") + + if isinstance(credential, (AzureKeyCredential, AsyncTokenCredential)): + return credential + + raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict") async def _get_client(self) -> SearchClient: - """Get the search client for the configured index.""" + """Get the search client for the configured index. 
+ + Returns: + SearchClient: Initialized search client + + Raises: + ValueError: If index doesn't exist or authentication fails + """ if self._client is not None: return self._client @@ -272,23 +397,14 @@ async def _get_client(self) -> SearchClient: credential=self.search_config.credential, api_version=self.search_config.api_version, ) - - assert self._client is not None return self._client except ResourceNotFoundError as e: - raise ValueError( - f"Index '{self.search_config.index_name}' not found. " - f"Please check if the index exists in your Azure AI Search service at {self.search_config.endpoint}" - ) from e + raise ValueError(f"Index '{self.search_config.index_name}' not found in Azure AI Search service.") from e except HttpResponseError as e: - if "401" in str(e): - raise ValueError( - f"Authentication failed. Please check your API key or credentials. Error: {str(e)}" - ) from e - elif "403" in str(e): - raise ValueError( - f"Permission denied. Please check that your credentials have access to this index. Error: {str(e)}" - ) from e + if e.status_code == 401: + raise ValueError("Authentication failed. 
Please check your credentials.") from e + elif e.status_code == 403: + raise ValueError("Permission denied to access this index.") from e else: raise ValueError(f"Error connecting to Azure AI Search: {str(e)}") from e except Exception as e: @@ -304,311 +420,170 @@ async def run( cancellation_token: Optional token to cancel the operation Returns: - Search results - """ - if isinstance(args, str) and not args.strip(): - raise ValueError("Invalid search query format: Query cannot be empty") + SearchResults: Container with search results and metadata + Raises: + ValueError: If the search query is empty or invalid + ValueError: If there is an authentication error or other search issue + asyncio.CancelledError: If the operation is cancelled + """ if isinstance(args, str): + if not args.strip(): + raise ValueError("Search query cannot be empty") search_query = SearchQuery(query=args) elif isinstance(args, dict) and "query" in args: search_query = SearchQuery(query=args["query"]) elif isinstance(args, SearchQuery): search_query = args else: - raise ValueError(f"Invalid search query format: {args}. Expected string, dict with 'query', or SearchQuery") + raise ValueError("Invalid search query format. 
Expected string, dict with 'query', or SearchQuery") + + if cancellation_token is not None and cancellation_token.is_cancelled(): + raise asyncio.CancelledError("Operation cancelled") + + cache_key = "" + if self.search_config.enable_caching: + cache_key_parts = [ + search_query.query, + str(self.search_config.top), + self.search_config.query_type, + ",".join(sorted(self.search_config.search_fields or [])), + ",".join(sorted(self.search_config.select_fields or [])), + ",".join(sorted(self.search_config.vector_fields or [])), + str(self.search_config.filter or ""), + str(self.search_config.semantic_config_name or ""), + ] + cache_key = ":".join(filter(None, cache_key_parts)) + if cache_key in self._cache: + cache_entry = self._cache[cache_key] + cache_age = time.time() - cache_entry["timestamp"] + if cache_age < self.search_config.cache_ttl_seconds: + logger.debug(f"Using cached results for query: {search_query.query}") + return SearchResults( + results=[ + SearchResult(score=r.score, content=r.content, metadata=r.metadata) + for r in cache_entry["results"] + ] + ) try: - if cancellation_token is not None and cancellation_token.is_cancelled(): - raise Exception("Operation cancelled") + search_kwargs: Dict[str, Any] = {} - if self.search_config.enable_caching: - cache_key = f"{search_query.query}:{self.search_config.top}" - if cache_key in self._cache: - cache_entry = self._cache[cache_key] - cache_age = time.time() - cache_entry["timestamp"] - if cache_age < self.search_config.cache_ttl_seconds: - logger.debug(f"Using cached results for query: {search_query.query}") - return SearchResults( - results=[ - SearchResult(score=r.score, content=r.content, metadata=r.metadata) - for r in cache_entry["results"] - ] - ) + if self.search_config.query_type != "vector": + search_kwargs["search_text"] = search_query.query + search_kwargs["query_type"] = self.search_config.query_type + + if self.search_config.search_fields: + search_kwargs["search_fields"] = 
self.search_config.search_fields # type: ignore[assignment] - search_options: Dict[str, Any] = {} - search_options["query_type"] = self.search_config.query_type + if self.search_config.query_type == "semantic" and self.search_config.semantic_config_name: + search_kwargs["semantic_configuration_name"] = self.search_config.semantic_config_name if self.search_config.select_fields: - search_options["select"] = self.search_config.select_fields + search_kwargs["select"] = self.search_config.select_fields # type: ignore[assignment] + if self.search_config.filter: + search_kwargs["filter"] = str(self.search_config.filter) + if self.search_config.top is not None: + search_kwargs["top"] = self.search_config.top # type: ignore[assignment] - if self.search_config.search_fields: - search_options["search_fields"] = self.search_config.search_fields + if self.search_config.vector_fields and len(self.search_config.vector_fields) > 0: + if not search_query.query: + raise ValueError("Query text cannot be empty for vector search operations") - if self.search_config.filter: - search_options["filter"] = self.search_config.filter + use_client_side_embeddings = bool( + self.search_config.embedding_model and self.search_config.embedding_provider + ) - if self.search_config.top is not None: - search_options["top"] = self.search_config.top - - if self.search_config.query_type == "fulltext" and self.search_config.semantic_config_name is not None: - search_options["query_type"] = "semantic" - search_options["semantic_configuration_name"] = self.search_config.semantic_config_name - - text_query = search_query.query - if self.search_config.query_type == "vector" or ( - self.search_config.vector_fields and len(self.search_config.vector_fields) > 0 - ): - if self.search_config.vector_fields: - vector_fields_list = self.search_config.vector_fields - search_options["vector_queries"] = [ - VectorizableTextQuery( - text=search_query.query, k_nearest_neighbors=int(self.search_config.top or 5), 
fields=field + vector_queries: List[Union[VectorizedQuery, VectorizableTextQuery]] = [] + if use_client_side_embeddings: + from azure.search.documents.models import VectorizedQuery + + embedding_vector: List[float] = await self._get_embedding(search_query.query) + for field_spec in self.search_config.vector_fields: + fields = field_spec if isinstance(field_spec, str) else ",".join(field_spec) + vector_queries.append( + VectorizedQuery( + vector=embedding_vector, + k_nearest_neighbors=self.search_config.top or 5, + fields=fields, + kind="vector", + ) + ) + else: + from azure.search.documents.models import VectorizableTextQuery + + for field in self.search_config.vector_fields: + fields = field if isinstance(field, str) else ",".join(field) + vector_queries.append( + VectorizableTextQuery( # type: ignore + text=search_query.query, + k_nearest_neighbors=self.search_config.top or 5, + fields=fields, + kind="vectorizable", + ) ) - for field in vector_fields_list - ] - - client = await self._get_client() - results: List[SearchResult] = [] - # Use the persistent client directly. Do NOT close after each operation. - # WARNING: The SearchClient must live as long as the tool/agent is in use. 
- search_future = client.search(text_query, **search_options) # type: ignore + search_kwargs["vector_queries"] = vector_queries # type: ignore[assignment] if cancellation_token is not None: - import asyncio - - # Using explicit type ignores to handle Azure SDK type complexity - async def awaitable_wrapper(): # type: ignore # pyright: ignore[reportUnknownVariableType,reportUnknownLambdaType,reportUnknownMemberType] - return await search_future # pyright: ignore[reportUnknownVariableType] + dummy_task = asyncio.create_task(asyncio.sleep(60)) + cancellation_token.link_future(dummy_task) - task = asyncio.create_task(awaitable_wrapper()) # type: ignore # pyright: ignore[reportUnknownVariableType] - cancellation_token.link_future(task) # pyright: ignore[reportUnknownArgumentType] - search_results = await task # pyright: ignore[reportUnknownVariableType] + def is_cancelled() -> bool: + return cancellation_token.is_cancelled() else: - search_results = await search_future # pyright: ignore[reportUnknownVariableType] - async for doc in search_results: # type: ignore - search_doc: Any = doc - doc_dict: Dict[str, Any] = {} + def is_cancelled() -> bool: + return False + + client = await self._get_client() + search_results: SearchResultsIterable = await client.search(**search_kwargs) # type: ignore[arg-type] + + results: List[SearchResult] = [] + async for doc in search_results: + if is_cancelled(): + raise asyncio.CancelledError("Operation was cancelled") try: - if hasattr(search_doc, "items") and callable(search_doc.items): - dict_like_doc = cast(Dict[str, Any], search_doc) - for key, value in dict_like_doc.items(): - doc_dict[str(key)] = value - else: - for key in [ - k - for k in dir(search_doc) - if not k.startswith("_") and not callable(getattr(search_doc, k, None)) - ]: - doc_dict[key] = getattr(search_doc, key) + metadata: Dict[str, Any] = {} + content: Dict[str, Any] = {} + + for key, value in doc.items(): + if isinstance(key, str) and key.startswith(("@", "_")): + 
metadata[key] = value + else: + content[str(key)] = value + + score = float(metadata.get("@search.score", 0.0)) + results.append(SearchResult(score=score, content=content, metadata=metadata)) except Exception as e: logger.warning(f"Error processing search document: {e}") continue - metadata: Dict[str, Any] = {} - content: Dict[str, Any] = {} - for key, value in doc_dict.items(): - key_str: str = str(key) - if key_str.startswith("@") or key_str.startswith("_"): - metadata[key_str] = value - else: - content[key_str] = value - - score: float = 0.0 - if "@search.score" in doc_dict: - score = float(doc_dict["@search.score"]) - - result = SearchResult( - score=score, - content=content, - metadata=metadata, - ) - results.append(result) - if self.search_config.enable_caching: - cache_key = f"{text_query}_{self.search_config.top}" self._cache[cache_key] = {"results": results, "timestamp": time.time()} - return SearchResults( - results=[SearchResult(score=r.score, content=r.content, metadata=r.metadata) for r in results] - ) + return SearchResults(results=results) + + except asyncio.CancelledError: + raise except Exception as e: + error_msg = str(e) if isinstance(e, HttpResponseError): if hasattr(e, "message") and e.message: - if "401 unauthorized" in e.message.lower() or "access denied" in e.message.lower(): - raise ValueError( - f"Authentication failed: {e.message}. Please check your API key and credentials." - ) from e - elif "500" in e.message: - raise ValueError(f"Error from Azure AI Search: {e.message}") from e - else: - raise ValueError(f"Error from Azure AI Search: {e.message}") from e - - if hasattr(self, "_name") and self._name == "test_search": - if ( - hasattr(self, "_credential") - and isinstance(self._credential, AzureKeyCredential) - and self._credential.key == "invalid-key" - ): - raise ValueError( - "Authentication failed: 401 Unauthorized. Please check your API key and credentials." 
- ) from e - elif "invalid status" in str(e).lower(): - raise ValueError( - "Error from Azure AI Search: 500 Internal Server Error: Something went wrong" - ) from e + error_msg = e.message - error_msg = str(e) if "not found" in error_msg.lower(): - raise ValueError( - f"Index '{self.search_config.index_name}' not found. Please check the index name and try again." - ) from e + raise ValueError(f"Index '{self.search_config.index_name}' not found.") from e elif "unauthorized" in error_msg.lower() or "401" in error_msg: - raise ValueError( - f"Authentication failed: {error_msg}. Please check your API key and credentials." - ) from e + raise ValueError(f"Authentication failed: {error_msg}") from e else: raise ValueError(f"Error from Azure AI Search: {error_msg}") from e - @abstractmethod - async def _get_embedding(self, query: str) -> List[float]: - """Generate embedding vector for the query text. - - This method must be implemented by subclasses to provide embeddings for vector search. - - Args: - query (str): The text to generate embeddings for. - - Returns: - List[float]: The embedding vector as a list of floats. - """ - pass - - def _to_config(self) -> Any: - """Get the tool configuration. - - Returns: - Any: The search configuration object - """ + def _to_config(self) -> AzureAISearchConfig: + """Convert the current instance to a configuration object.""" return self.search_config - def dump_component(self) -> ComponentModel: - """Serialize the tool to a component model. - - Returns: - ComponentModel: A serialized representation of the tool - """ - config = self._to_config() - return ComponentModel( - provider="autogen_ext.tools.azure.BaseAzureAISearchTool", - config=config.model_dump(exclude_none=True), - ) - - @classmethod - def _from_config(cls, config: Any) -> "BaseAzureAISearchTool": - """Create a tool instance from configuration. 
- - Args: - config (Any): The configuration object containing tool settings - - Returns: - BaseAzureAISearchTool: An initialized instance of the search tool - """ - query_type_str = getattr(config, "query_type", "keyword") - - query_type_mapping = { - "keyword": "keyword", - "simple": "fulltext", - "fulltext": "fulltext", - "vector": "vector", - "semantic": "semantic", - } - - query_type = cast( - Literal["keyword", "fulltext", "vector", "semantic"], query_type_mapping.get(query_type_str, "vector") - ) - - openai_client_attr = getattr(config, "openai_client", None) - if openai_client_attr is None: - raise ValueError("openai_client must be provided in config") - - embedding_model_attr = getattr(config, "embedding_model", "") - if not embedding_model_attr: - raise ValueError("embedding_model must be specified in config") - - # If query_type="semantic", you must provide a valid semantic_config_name. - # If query_type is anything else, semantic_config_name is ignored. - return cls( - name=getattr(config, "name", ""), - endpoint=getattr(config, "endpoint", ""), - index_name=getattr(config, "index_name", ""), - credential=getattr(config, "credential", {}), - description=getattr(config, "description", None), - api_version=getattr(config, "api_version", "2023-11-01"), - query_type=query_type, - search_fields=getattr(config, "search_fields", None), - select_fields=getattr(config, "select_fields", None), - vector_fields=getattr(config, "vector_fields", None), - top=getattr(config, "top", None), - filter=getattr(config, "filter", None), - semantic_config_name=getattr(config, "semantic_config_name", None), - enable_caching=getattr(config, "enable_caching", False), - cache_ttl_seconds=getattr(config, "cache_ttl_seconds", 300), - ) - - @overload - @classmethod - def load_component( - cls, model: Union[ComponentModel, Dict[str, Any]], expected: None = None - ) -> "BaseAzureAISearchTool": ... 
- - @overload - @classmethod - def load_component( - cls, model: Union[ComponentModel, Dict[str, Any]], expected: Type[ExpectedType] - ) -> ExpectedType: ... - - @classmethod - def load_component( - cls, - model: Union[ComponentModel, Dict[str, Any]], - expected: Optional[Type[ExpectedType]] = None, - ) -> Union["BaseAzureAISearchTool", ExpectedType]: - """Load the tool from a component model. - - Args: - model (Union[ComponentModel, Dict[str, Any]]): The component configuration. - expected (Optional[Type[ExpectedType]]): Optional component class for deserialization. - - Returns: - Union[BaseAzureAISearchTool, ExpectedType]: An instance of the tool. - - Raises: - ValueError: If the component configuration is invalid. - """ - if expected is not None and not issubclass(expected, BaseAzureAISearchTool): - raise TypeError(f"Cannot create instance of {expected} from AzureAISearchConfig") - - target_class = expected if expected is not None else cls - assert hasattr(target_class, "_from_config"), f"{target_class} has no _from_config method" - - if isinstance(model, ComponentModel) and hasattr(model, "config"): - config_dict = model.config - elif isinstance(model, dict): - config_dict = model - else: - raise ValueError(f"Invalid component configuration: {model}") - - config = AzureAISearchConfig(**config_dict) - - tool = target_class._from_config(config) - if expected is None: - return tool - return cast(ExpectedType, tool) - @property def schema(self) -> ToolSchema: """Return the schema for the tool.""" @@ -625,245 +600,158 @@ def schema(self) -> ToolSchema: } def return_value_as_string(self, value: SearchResults) -> str: - """Convert the search results to a string representation. - - This method is used to format the search results in a way that's suitable - for display to the user or for consumption by language models. - - Args: - value (List[SearchResult]): The search results to convert. - - Returns: - str: A formatted string representation of the search results. 
- """ + """Convert the search results to a string representation.""" if not value.results: return "No results found." result_strings: List[str] = [] for i, result in enumerate(value.results, 1): - content_str = ", ".join(f"{k}: {v}" for k, v in result.content.items()) + content_items = [f"{k}: {str(v) if v is not None else 'None'}" for k, v in result.content.items()] + content_str = ", ".join(content_items) result_strings.append(f"Result {i} (Score: {result.score:.2f}): {content_str}") return "\n".join(result_strings) + @classmethod + def _validate_config( + cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"] + ) -> None: + """Validate configuration for specific search types.""" + credential = config_dict.get("credential") + if isinstance(credential, str): + raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict") + if isinstance(credential, dict) and "api_key" not in credential: + raise ValueError("If credential is a dict, it must contain an 'api_key' key") + + try: + _ = AzureAISearchConfig(**config_dict) + except Exception as e: + raise ValueError(f"Invalid configuration: {str(e)}") from e + + if search_type == "vector": + vector_fields = config_dict.get("vector_fields") + if not vector_fields or len(vector_fields) == 0: + raise ValueError("vector_fields must contain at least one field name for vector search") + + elif search_type == "hybrid": + vector_fields = config_dict.get("vector_fields") + search_fields = config_dict.get("search_fields") + + if not vector_fields or len(vector_fields) == 0: + raise ValueError("vector_fields must contain at least one field name for hybrid search") + + if not search_fields or len(search_fields) == 0: + raise ValueError("search_fields must contain at least one field name for hybrid search") + + @classmethod + @abstractmethod + def _from_config(cls, config: AzureAISearchConfig) -> "BaseAzureAISearchTool": + """Create a tool instance from a configuration 
object. + + This is an abstract method that must be implemented by subclasses. + """ + if cls is BaseAzureAISearchTool: + raise NotImplementedError( + "BaseAzureAISearchTool is an abstract base class and cannot be instantiated directly. " + "Use a concrete implementation like AzureAISearchTool." + ) + raise NotImplementedError("Subclasses must implement _from_config") + + @abstractmethod + async def _get_embedding(self, query: str) -> List[float]: + """Generate embedding vector for the query text.""" + raise NotImplementedError("Subclasses must implement _get_embedding") + _allow_private_constructor = ContextVar("_allow_private_constructor", default=False) -class AzureAISearchTool(BaseAzureAISearchTool): +class AzureAISearchTool(EmbeddingProviderMixin, BaseAzureAISearchTool): """Azure AI Search tool for querying Azure search indexes. This tool provides a simplified interface for querying Azure AI Search indexes using - various search methods. The tool supports four main search types: - - 1. Keyword Search: Traditional text-based search using Azure's text analysis - 2. Full-Text Search: Enhanced text search with language-specific analyzers - 3. Vector Search: Semantic similarity search using vector embeddings - 4. Hybrid Search: Combines fulltext and vector search for comprehensive results - - You should use the factory methods to create instances for specific search types: - - create_keyword_search() - - create_full_text_search() - - create_vector_search() - - create_hybrid_search() - """ + various search methods. 
It's recommended to use the factory methods to create + instances tailored for specific search types: - def __init__( - self, - name: str, - endpoint: str, - index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], - query_type: Literal["keyword", "fulltext", "vector", "semantic"], - search_fields: Optional[List[str]] = None, - select_fields: Optional[List[str]] = None, - vector_fields: Optional[List[str]] = None, - filter: Optional[str] = None, - top: Optional[int] = 5, - **kwargs: Any, - ) -> None: - if not _allow_private_constructor.get(): - raise RuntimeError( - "Constructor is private. Use factory methods like create_keyword_search(), " - "create_vector_search(), or create_hybrid_search() instead." - ) + 1. **Full-Text Search**: For traditional keyword-based searches, Lucene queries, or + semantically re-ranked results. + - Use `AzureAISearchTool.create_full_text_search()` + - Supports `query_type`: "simple" (keyword), "full" (Lucene), "semantic". - super().__init__( - name=name, - endpoint=endpoint, - index_name=index_name, - credential=credential, - query_type=query_type, - search_fields=search_fields, - select_fields=select_fields, - vector_fields=vector_fields, - filter=filter, - top=top, - **kwargs, - ) + 2. **Vector Search**: For pure similarity searches based on vector embeddings. + - Use `AzureAISearchTool.create_vector_search()` - @classmethod - @overload - def load_component( - cls, model: Union[ComponentModel, Dict[str, Any]], expected: None = None - ) -> "AzureAISearchTool": ... + 3. **Hybrid Search**: For combining vector search with full-text or semantic search + to get the benefits of both. + - Use `AzureAISearchTool.create_hybrid_search()` + - The text component can be "simple", "full", or "semantic" via the `query_type` parameter. - @classmethod - @overload - def load_component( - cls, model: Union[ComponentModel, Dict[str, Any]], expected: Type[ExpectedType] - ) -> ExpectedType: ... 
+ Each factory method configures the tool with appropriate defaults and validations + for the chosen search strategy. + + .. warning:: + If you set `query_type="semantic"`, you must also provide a valid `semantic_config_name`. + This configuration must be set up in your Azure AI Search index beforehand. + """ + + component_provider_override = "autogen_ext.tools.azure.AzureAISearchTool" @classmethod - def load_component( - cls, model: Union[ComponentModel, Dict[str, Any]], expected: Optional[Type[ExpectedType]] = None - ) -> Union["AzureAISearchTool", ExpectedType]: - """Load a component from a component model. + def _from_config(cls, config: AzureAISearchConfig) -> "AzureAISearchTool": + """Create a tool instance from a configuration object. Args: - model: The component model or dictionary with configuration - expected: Optional expected return type + config: The configuration object with tool settings Returns: - An initialized AzureAISearchTool instance + AzureAISearchTool: An initialized tool instance """ token = _allow_private_constructor.set(True) try: - if isinstance(model, dict): - model = ComponentModel(**model) - - config = model.config - - query_type_str = config.get("query_type", "keyword") - - query_type_mapping = { - "keyword": "keyword", - "simple": "fulltext", - "fulltext": "fulltext", - "vector": "vector", - "semantic": "semantic", - } - - query_type = cast( - Literal["keyword", "fulltext", "vector", "semantic"], query_type_mapping.get(query_type_str, "vector") - ) - instance = cls( - name=config.get("name", ""), - endpoint=config.get("endpoint", ""), - index_name=config.get("index_name", ""), - credential=config.get("credential", {}), - query_type=query_type, - search_fields=config.get("search_fields"), - select_fields=config.get("select_fields"), - vector_fields=config.get("vector_fields"), - top=config.get("top"), - filter=config.get("filter"), - enable_caching=config.get("enable_caching", False), - 
cache_ttl_seconds=config.get("cache_ttl_seconds", 300), + name=config.name, + description=config.description or "", + endpoint=config.endpoint, + index_name=config.index_name, + credential=config.credential, + api_version=config.api_version, + query_type=config.query_type, + search_fields=config.search_fields, + select_fields=config.select_fields, + vector_fields=config.vector_fields, + top=config.top, + filter=config.filter, + semantic_config_name=config.semantic_config_name, + enable_caching=config.enable_caching, + cache_ttl_seconds=config.cache_ttl_seconds, + embedding_provider=config.embedding_provider, + embedding_model=config.embedding_model, + openai_api_key=config.openai_api_key, + openai_api_version=config.openai_api_version, + openai_endpoint=config.openai_endpoint, ) - - if expected is not None: - return cast(ExpectedType, instance) return instance finally: _allow_private_constructor.reset(token) @classmethod - def _validate_common_params(cls, name: str, endpoint: str, index_name: str, credential: Any) -> None: - """Validate common parameters across all factory methods. 
- - Args: - name: Tool name - endpoint: Azure Search endpoint URL - index_name: Name of search index - credential: Authentication credentials - - Raises: - ValueError: If any parameter is invalid - """ - if not endpoint or not endpoint.startswith(("http://", "https://")): - raise ValueError("endpoint must be a valid URL starting with http:// or https://") - - if not index_name: - raise ValueError("index_name cannot be empty") - - if not name: - raise ValueError("name cannot be empty") - - if not credential: - raise ValueError("credential cannot be None") - - @classmethod - def create_keyword_search( - cls, - name: str, - endpoint: str, - index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], - search_fields: Optional[List[str]] = None, - select_fields: Optional[List[str]] = None, - filter: Optional[str] = None, - top: Optional[int] = 5, - **kwargs: Any, + def _create_from_params( + cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"] ) -> "AzureAISearchTool": - """Factory method to create a keyword search tool. - - Keyword search performs traditional text-based search, good for finding documents - containing specific terms or exact matches to your query. + """Private helper to create an instance from parameters after validation. 
Args: - name (str): The name of the tool - endpoint (str): The URL of your Azure AI Search service - index_name (str): The name of the search index - credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials - search_fields (Optional[List[str]]): Fields to search within for text search - select_fields (Optional[List[str]]): Fields to include in results - filter (Optional[str]): OData filter expression to filter results - top (Optional[int]): Maximum number of results to return - **kwargs (Any): Additional configuration options + config_dict: Dictionary with configuration parameters + search_type: Type of search for validation Returns: - An initialized keyword search tool - - Example Usage: - .. code-block:: python - - # type: ignore - # Example of using keyword search with Azure AI Search - from autogen_ext.tools.azure import AzureAISearchTool - from azure.core.credentials import AzureKeyCredential - - # Create a keyword search tool - keyword_search = AzureAISearchTool.create_keyword_search( - name="keyword_search", - endpoint="https://your-service.search.windows.net", - index_name="your-index", - credential=AzureKeyCredential("your-api-key"), - search_fields=["title", "content"], - select_fields=["id", "title", "content", "category"], - top=10, - ) - - # The search tool can be used with an Agent - # assistant = Agent("assistant", tools=[keyword_search]) + Configured AzureAISearchTool instance """ - cls._validate_common_params(name, endpoint, index_name, credential) + cls._validate_config(config_dict, search_type) token = _allow_private_constructor.set(True) try: - return cls( - name=name, - endpoint=endpoint, - index_name=index_name, - credential=credential, - query_type="keyword", - search_fields=search_fields, - select_fields=select_fields, - filter=filter, - top=top, - **kwargs, - ) + return cls(**config_dict) finally: _allow_private_constructor.reset(token) @@ -873,72 +761,114 @@ def create_full_text_search( name: str, 
endpoint: str, index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], + description: Optional[str] = None, + api_version: Optional[str] = None, + query_type: Literal["simple", "full", "semantic"] = "simple", search_fields: Optional[List[str]] = None, select_fields: Optional[List[str]] = None, - filter: Optional[str] = None, top: Optional[int] = 5, - **kwargs: Any, + filter: Optional[str] = None, + semantic_config_name: Optional[str] = None, + enable_caching: bool = False, + cache_ttl_seconds: int = 300, ) -> "AzureAISearchTool": - """Factory method to create a full-text search tool. + """Create a tool for traditional text-based searches. - Full-text search uses advanced text analysis (stemming, lemmatization, etc.) - to provide more comprehensive text matching than basic keyword search. + This factory method creates an AzureAISearchTool optimized for full-text search, + supporting keyword matching, Lucene syntax, and semantic search capabilities. 
Args: - name (str): The name of the tool - endpoint (str): The URL of your Azure AI Search service - index_name (str): The name of the search index - credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials - search_fields (Optional[List[str]]): Fields to search within - select_fields (Optional[List[str]]): Fields to include in results - filter (Optional[str]): OData filter expression to filter results - top (Optional[int]): Maximum number of results to return - **kwargs (Any): Additional configuration options + name: The name of this tool instance + endpoint: The full URL of your Azure AI Search service + index_name: Name of the search index to query + credential: Azure credential for authentication (API key or token) + description: Optional description explaining the tool's purpose + api_version: Azure AI Search API version to use + query_type: Type of text search to perform: + + • **simple** : Basic keyword search that matches exact terms and their variations + • **full**: Advanced search using Lucene query syntax for complex queries + • **semantic**: AI-powered search that understands meaning and context, providing enhanced relevance ranking + search_fields: Fields to search within documents + select_fields: Fields to return in search results + top: Maximum number of results to return (default: 5) + filter: OData filter expression to refine search results + semantic_config_name: Semantic configuration name (required for semantic query_type) + enable_caching: Whether to cache search results + cache_ttl_seconds: How long to cache results in seconds Returns: - An initialized full-text search tool + An initialized AzureAISearchTool for full-text search - Example Usage: + Example: .. 
code-block:: python - # type: ignore - # Example of using full-text search with Azure AI Search - from autogen_ext.tools.azure import AzureAISearchTool from azure.core.credentials import AzureKeyCredential + from autogen_ext.tools.azure import AzureAISearchTool + + # Basic keyword search + tool = AzureAISearchTool.create_full_text_search( + name="doc-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + query_type="simple", # Enable keyword search + search_fields=["content", "title"], # Required: fields to search within + select_fields=["content", "title", "url"], # Optional: fields to return + top=5, + ) - # Create a full-text search tool - full_text_search = AzureAISearchTool.create_full_text_search( - name="document_search", - endpoint="https://your-search-service.search.windows.net", - index_name="your-index", - credential=AzureKeyCredential("your-api-key"), - search_fields=["title", "content"], - select_fields=["title", "content", "category", "url"], - top=10, + # full text (Lucene query) search + full_text_tool = AzureAISearchTool.create_full_text_search( + name="doc-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + query_type="full", # Enable Lucene query syntax + search_fields=["content", "title"], # Required: fields to search within + select_fields=["content", "title", "url"], # Optional: fields to return + top=5, + ) + + # Semantic search with re-ranking + # Note: Make sure your index has semantic configuration enabled + semantic_tool = AzureAISearchTool.create_full_text_search( + name="semantic-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + query_type="semantic", # Enable 
semantic ranking + semantic_config_name="", # Required for semantic search + search_fields=["content", "title"], # Required: fields to search within + select_fields=["content", "title", "url"], # Optional: fields to return + top=5, ) # The search tool can be used with an Agent - # assistant = Agent("assistant", tools=[full_text_search]) + # assistant = Agent("assistant", tools=[semantic_tool]) """ - cls._validate_common_params(name, endpoint, index_name, credential) + if query_type == "semantic" and not semantic_config_name: + raise ValueError("semantic_config_name is required when query_type is 'semantic'") + + config_dict = { + "name": name, + "endpoint": endpoint, + "index_name": index_name, + "credential": credential, + "description": description, + "api_version": api_version or DEFAULT_API_VERSION, + "query_type": query_type, + "search_fields": search_fields, + "select_fields": select_fields, + "top": top, + "filter": filter, + "semantic_config_name": semantic_config_name, + "enable_caching": enable_caching, + "cache_ttl_seconds": cache_ttl_seconds, + } - token = _allow_private_constructor.set(True) - try: - return cls( - name=name, - endpoint=endpoint, - index_name=index_name, - credential=credential, - query_type="fulltext", - search_fields=search_fields, - select_fields=select_fields, - filter=filter, - top=top, - **kwargs, - ) - finally: - _allow_private_constructor.reset(token) + return cls._create_from_params(config_dict, "full_text") @classmethod def create_vector_search( @@ -946,76 +876,128 @@ def create_vector_search( name: str, endpoint: str, index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], vector_fields: List[str], + description: Optional[str] = None, + api_version: Optional[str] = None, select_fields: Optional[List[str]] = None, + top: int = 5, filter: Optional[str] = None, - top: Optional[int] = 5, - **kwargs: Any, + 
enable_caching: bool = False, + cache_ttl_seconds: int = 300, + embedding_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + openai_api_key: Optional[str] = None, + openai_api_version: Optional[str] = None, + openai_endpoint: Optional[str] = None, ) -> "AzureAISearchTool": - """Factory method to create a vector search tool. + """Create a tool for pure vector/similarity search. - Vector search uses embedding vectors to find semantically similar content, enabling - the discovery of related information even when different terminology is used. + This factory method creates an AzureAISearchTool optimized for vector search, + allowing for semantic similarity-based matching using vector embeddings. Args: - name (str): The name of the tool - endpoint (str): The URL of your Azure AI Search service - index_name (str): The name of the search index - credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials - vector_fields (List[str]): Fields containing vector embeddings for similarity search - select_fields (Optional[List[str]]): Fields to include in results - filter (Optional[str]): OData filter expression to filter results - top (Optional[int]): Maximum number of results to return - **kwargs (Any): Additional configuration options + name: The name of this tool instance + endpoint: The full URL of your Azure AI Search service + index_name: Name of the search index to query + credential: Azure credential for authentication (API key or token) + vector_fields: Fields to use for vector search (required) + description: Optional description explaining the tool's purpose + api_version: Azure AI Search API version to use + select_fields: Fields to return in search results + top: Maximum number of results to return / k in k-NN (default: 5) + filter: OData filter expression to refine search results + enable_caching: Whether to cache search results + cache_ttl_seconds: How long to cache results in seconds + 
embedding_provider: Provider for client-side embeddings (e.g., 'azure_openai', 'openai') + embedding_model: Model for client-side embeddings (e.g., 'text-embedding-ada-002') + openai_api_key: API key for OpenAI/Azure OpenAI embeddings + openai_api_version: API version for Azure OpenAI embeddings + openai_endpoint: Endpoint URL for Azure OpenAI embeddings Returns: - An initialized vector search tool + An initialized AzureAISearchTool for vector search + + Raises: + ValueError: If vector_fields is empty + ValueError: If embedding_provider is 'azure_openai' without openai_endpoint + ValueError: If required parameters are missing or invalid Example Usage: .. code-block:: python - # type: ignore - # Example of using vector search with Azure AI Search - from autogen_ext.tools.azure import AzureAISearchTool from azure.core.credentials import AzureKeyCredential + from autogen_ext.tools.azure import AzureAISearchTool - # Create a vector search tool - vector_search = AzureAISearchTool.create_vector_search( - name="vector_search", - endpoint="https://your-search-service.search.windows.net", - index_name="your-index", - credential=AzureKeyCredential("your-api-key"), - vector_fields=["embedding"], - select_fields=["title", "content", "url"], + # Vector search with service-side vectorization + tool = AzureAISearchTool.create_vector_search( + name="vector-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + vector_fields=["content_vector"], # Your vector field name + select_fields=["content", "title", "url"], # Fields to return in results top=5, ) - # The search tool can be used with an Agent - # assistant = Agent("assistant", tools=[vector_search]) + # Vector search with Azure OpenAI embeddings + azure_openai_tool = AzureAISearchTool.create_vector_search( + name="azure-openai-vector-search", + 
endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + vector_fields=["content_vector"], + embedding_provider="azure_openai", # Use Azure OpenAI for embeddings + embedding_model="text-embedding-ada-002", # Embedding model to use + openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint + openai_api_key="", # Your Azure OpenAI key + openai_api_version="2024-02-15-preview", # Azure OpenAI API version + select_fields=["content", "title", "url"], # Fields to return in results + top=5, + ) - """ - cls._validate_common_params(name, endpoint, index_name, credential) + # Vector search with OpenAI embeddings + openai_tool = AzureAISearchTool.create_vector_search( + name="openai-vector-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + vector_fields=["content_vector"], + embedding_provider="openai", # Use OpenAI for embeddings + embedding_model="text-embedding-ada-002", # Embedding model to use + openai_api_key="", # Your OpenAI API key + select_fields=["content", "title", "url"], # Fields to return in results + top=5, + ) - if not vector_fields or len(vector_fields) == 0: - raise ValueError("vector_fields must contain at least one field name") + # Use the tool with an Agent + # assistant = Agent("assistant", tools=[azure_openai_tool]) + """ + if embedding_provider == "azure_openai" and not openai_endpoint: + raise ValueError("openai_endpoint is required when embedding_provider is 'azure_openai'") + + config_dict = { + "name": name, + "endpoint": endpoint, + "index_name": index_name, + "credential": credential, + "description": description, + "api_version": api_version or DEFAULT_API_VERSION, + "query_type": "vector", + "select_fields": select_fields, + "vector_fields": vector_fields, + "top": top, + "filter": filter, + "enable_caching": enable_caching, + "cache_ttl_seconds": cache_ttl_seconds, + "embedding_provider": 
embedding_provider, + "embedding_model": embedding_model, + "openai_api_key": openai_api_key, + "openai_api_version": openai_api_version, + "openai_endpoint": openai_endpoint, + } - token = _allow_private_constructor.set(True) - try: - return cls( - name=name, - endpoint=endpoint, - index_name=index_name, - credential=credential, - query_type="vector", - vector_fields=vector_fields, - select_fields=select_fields, - filter=filter, - top=top, - **kwargs, - ) - finally: - _allow_private_constructor.reset(token) + return cls._create_from_params(config_dict, "vector") @classmethod def create_hybrid_search( @@ -1023,171 +1005,131 @@ def create_hybrid_search( name: str, endpoint: str, index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], + credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]], vector_fields: List[str], - search_fields: Optional[List[str]] = None, + search_fields: List[str], + description: Optional[str] = None, + api_version: Optional[str] = None, + query_type: Literal["simple", "full", "semantic"] = "simple", select_fields: Optional[List[str]] = None, + top: int = 5, filter: Optional[str] = None, - top: Optional[int] = 5, - **kwargs: Any, + semantic_config_name: Optional[str] = None, + enable_caching: bool = False, + cache_ttl_seconds: int = 300, + embedding_provider: Optional[str] = None, + embedding_model: Optional[str] = None, + openai_api_key: Optional[str] = None, + openai_api_version: Optional[str] = None, + openai_endpoint: Optional[str] = None, ) -> "AzureAISearchTool": - """Factory method to create a hybrid search tool (text + vector). + """Create a tool that combines vector and text search capabilities. - Hybrid search combines text search (fulltext or semantic) with vector similarity - search to provide more comprehensive results. This is the recommended entrypoint for hybrid (text + vector) search. 
- The query_type will be 'semantic' if semantic_config_name is provided, otherwise 'fulltext'. + This factory method creates an AzureAISearchTool configured for hybrid search, + which combines the benefits of vector similarity and traditional text search. Args: - name (str): The name of the tool - endpoint (str): The URL of your Azure AI Search service - index_name (str): The name of the search index - credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials - vector_fields (List[str]): Fields containing vector embeddings for similarity search - search_fields (Optional[List[str]]): Fields to search within for text search - select_fields (Optional[List[str]]): Fields to include in results - filter (Optional[str]): OData filter expression to filter results - top (Optional[int]): Maximum number of results to return - **kwargs (Any): Additional configuration options + name: The name of this tool instance + endpoint: The full URL of your Azure AI Search service + index_name: Name of the search index to query + credential: Azure credential for authentication (API key or token) + vector_fields: Fields to use for vector search (required) + search_fields: Fields to use for text search (required) + description: Optional description explaining the tool's purpose + api_version: Azure AI Search API version to use + query_type: Type of text search to perform: + + • **simple**: Basic keyword search that matches exact terms and their variations + • **full**: Advanced search using Lucene query syntax for complex queries + • **semantic**: AI-powered search that understands meaning and context, providing enhanced relevance ranking + select_fields: Fields to return in search results + top: Maximum number of results to return (default: 5) + filter: OData filter expression to refine search results + semantic_config_name: Semantic configuration name (required if query_type="semantic") + enable_caching: Whether to cache search results + 
cache_ttl_seconds: How long to cache results in seconds + embedding_provider: Provider for client-side embeddings (e.g., 'azure_openai', 'openai') + embedding_model: Model for client-side embeddings (e.g., 'text-embedding-ada-002') + openai_api_key: API key for OpenAI/Azure OpenAI embeddings + openai_api_version: API version for Azure OpenAI embeddings + openai_endpoint: Endpoint URL for Azure OpenAI embeddings Returns: - An initialized hybrid search tool + An initialized AzureAISearchTool for hybrid search - Example Usage: + Raises: + ValueError: If vector_fields or search_fields is empty + ValueError: If query_type is "semantic" without semantic_config_name + ValueError: If embedding_provider is 'azure_openai' without openai_endpoint + ValueError: If required parameters are missing or invalid + + Example: .. code-block:: python - # type: ignore - # Example of using hybrid search with Azure AI Search - from autogen_ext.tools.azure import AzureAISearchTool from azure.core.credentials import AzureKeyCredential + from autogen_ext.tools.azure import AzureAISearchTool - # Create a hybrid search tool - hybrid_search = AzureAISearchTool.create_hybrid_search( - name="hybrid_search", - endpoint="https://your-search-service.search.windows.net", - index_name="your-index", - credential=AzureKeyCredential("your-api-key"), - vector_fields=["embedding_field"], - search_fields=["title", "content"], - select_fields=["title", "content", "url", "date"], - top=10, + # Basic hybrid search with service-side vectorization + tool = AzureAISearchTool.create_hybrid_search( + name="hybrid-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + vector_fields=["content_vector"], # Your vector field name + search_fields=["content", "title"], # Your searchable fields + top=5, ) - # The search tool can be used with an Agent - # assistant = 
Agent("researcher", tools=[hybrid_search]) - - - .. warning:: - - If you set ``query_type=\"semantic\"``, you must also provide a valid ``semantic_config_name``. - If you do not, the tool will default to the config name ``\"semantic\"``. - - """ - cls._validate_common_params(name, endpoint, index_name, credential) - - if not vector_fields or len(vector_fields) == 0: - raise ValueError("vector_fields must contain at least one field name") - - token = _allow_private_constructor.set(True) - try: - if kwargs.get("semantic_config_name"): - text_query_type = "semantic" - else: - text_query_type = "fulltext" - - from typing import cast - - return cls( - name=name, - endpoint=endpoint, - index_name=index_name, - credential=credential, - query_type=cast(Literal["keyword", "fulltext", "vector", "semantic"], text_query_type), - search_fields=search_fields, - select_fields=select_fields, - vector_fields=vector_fields, - filter=filter, - top=top, - **kwargs, - ) - finally: - _allow_private_constructor.reset(token) - - async def _get_embedding(self, query: str) -> List[float]: - """Generate embedding vector for the query text. - - This method handles generating embeddings for vector search functionality. - The embedding provider and model should be specified in the tool configuration. - - Args: - query (str): The text to generate embeddings for. - - Returns: - List[float]: The embedding vector as a list of floats. - - Raises: - ValueError: If the embedding configuration is missing or invalid. - """ - embedding_provider = getattr(self.search_config, "embedding_provider", None) - embedding_model = getattr(self.search_config, "embedding_model", None) - - if not embedding_provider or not embedding_model: - raise ValueError( - "To use vector search, you must provide embedding_provider and embedding_model in the configuration." 
- ) from None - - if embedding_provider.lower() == "azure_openai": - try: - from azure.identity import DefaultAzureCredential - from openai import AsyncAzureOpenAI - except ImportError: - raise ImportError( - "Azure OpenAI SDK is required for embedding generation. " - "Please install it with: uv add openai azure-identity" - ) from None - - api_key = None - if hasattr(self.search_config, "openai_api_key"): - api_key = self.search_config.openai_api_key - - api_version = getattr(self.search_config, "openai_api_version", "2023-05-15") - endpoint = getattr(self.search_config, "openai_endpoint", None) - - if not endpoint: - raise ValueError("OpenAI endpoint must be provided for Azure OpenAI embeddings") from None - - if api_key: - azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint) - else: - - def get_token() -> str: - credential = DefaultAzureCredential() - return credential.get_token("https://cognitiveservices.azure.com/.default").token - - azure_client = AsyncAzureOpenAI( - azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint + # Hybrid search with semantic ranking and Azure OpenAI embeddings + semantic_tool = AzureAISearchTool.create_hybrid_search( + name="semantic-hybrid-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + vector_fields=["content_vector"], + search_fields=["content", "title"], + query_type="semantic", # Enable semantic ranking + semantic_config_name="", # Your semantic config name + embedding_provider="azure_openai", # Use Azure OpenAI for embeddings + embedding_model="text-embedding-ada-002", # Embedding model to use + openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint + openai_api_key="", # Your Azure OpenAI key + openai_api_version="2024-02-15-preview", # Azure OpenAI API version + select_fields=["content", "title", "url"], # Fields to return in results + filter="language 
eq 'en'", # Optional OData filter + top=5, ) - response = await azure_client.embeddings.create(model=embedding_model, input=query) - return response.data[0].embedding - - elif embedding_provider.lower() == "openai": - try: - from openai import AsyncOpenAI - except ImportError: - raise ImportError( - "OpenAI SDK is required for embedding generation. " "Please install it with: uv add openai" - ) from None - - api_key = None - if hasattr(self.search_config, "openai_api_key"): - api_key = self.search_config.openai_api_key - - openai_client = AsyncOpenAI(api_key=api_key) + # The search tool can be used with an Agent + # assistant = Agent("assistant", tools=[semantic_tool]) + """ + if query_type == "semantic" and not semantic_config_name: + raise ValueError("semantic_config_name is required when query_type is 'semantic'") + + if embedding_provider == "azure_openai" and not openai_endpoint: + raise ValueError("openai_endpoint is required when embedding_provider is 'azure_openai'") + + config_dict = { + "name": name, + "endpoint": endpoint, + "index_name": index_name, + "credential": credential, + "description": description, + "api_version": api_version or DEFAULT_API_VERSION, + "query_type": query_type, + "search_fields": search_fields, + "select_fields": select_fields, + "vector_fields": vector_fields, + "top": top, + "filter": filter, + "semantic_config_name": semantic_config_name, + "enable_caching": enable_caching, + "cache_ttl_seconds": cache_ttl_seconds, + "embedding_provider": embedding_provider, + "embedding_model": embedding_model, + "openai_api_key": openai_api_key, + "openai_api_version": openai_api_version, + "openai_endpoint": openai_endpoint, + } - response = await openai_client.embeddings.create(model=embedding_model, input=query) - return response.data[0].embedding - else: - raise ValueError( - f"Unsupported embedding provider: {embedding_provider}. " - "Currently supported providers are 'azure_openai' and 'openai'." 
- ) from None + return cls._create_from_params(config_dict, "hybrid") diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py index a27fdd6776af..38fa1fb156ad 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py @@ -6,173 +6,180 @@ import logging from typing import ( - Any, - Dict, List, Literal, Optional, - Type, TypeVar, Union, ) -from azure.core.credentials import AzureKeyCredential, TokenCredential -from pydantic import BaseModel, Field, model_validator - -# Add explicit ignore for the specific model validator error -# pyright: reportArgumentType=false -# pyright: reportUnknownArgumentType=false -# pyright: reportUnknownVariableType=false +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential +from pydantic import BaseModel, Field, field_validator, model_validator T = TypeVar("T", bound="AzureAISearchConfig") logger = logging.getLogger(__name__) +QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"] +DEFAULT_API_VERSION = "2023-10-01-preview" + class AzureAISearchConfig(BaseModel): - """Configuration for Azure AI Search tool. + """Configuration for Azure AI Search with validation. - This class defines the configuration parameters for :class:`AzureAISearchTool`. - It provides options for customizing search behavior including query types, - field selection, authentication, retry policies, and caching strategies. + This class defines the configuration parameters for Azure AI Search tools, including + authentication, search behavior, caching, and embedding settings. .. note:: - - This class requires the :code:`azure` extra for the :code:`autogen-ext` package. + This class requires the ``azure`` extra for the ``autogen-ext`` package. .. 
code-block:: bash pip install -U "autogen-ext[azure]" - Example: + .. note:: + **Prerequisites:** + + 1. An Azure AI Search service must be created in your Azure subscription. + 2. The search index must be properly configured for your use case: + + - For vector search: Index must have vector fields + - For semantic search: Index must have semantic configuration + - For hybrid search: Both vector fields and text fields must be configured + 3. Required packages: + + - Base functionality: ``azure-search-documents>=11.4.0`` + - For Azure OpenAI embeddings: ``openai azure-identity`` + - For OpenAI embeddings: ``openai`` + + Example Usage: .. code-block:: python from azure.core.credentials import AzureKeyCredential from autogen_ext.tools.azure import AzureAISearchConfig + # Basic configuration for full-text search config = AzureAISearchConfig( - name="doc_search", - endpoint="https://my-search.search.windows.net", - index_name="my-index", + name="doc-search", + endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint + index_name="", # Name of your search index + credential=AzureKeyCredential(""), # Your Azure AI Search admin key + query_type="simple", + search_fields=["content", "title"], # Update with your searchable fields + top=5, + ) + + # Configuration for vector search with Azure OpenAI embeddings + vector_config = AzureAISearchConfig( + name="vector-search", + endpoint="https://your-search.search.windows.net", + index_name="", credential=AzureKeyCredential(""), query_type="vector", - vector_fields=["embedding"], + vector_fields=["embedding"], # Update with your vector field name + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint + openai_api_key="", # Your Azure OpenAI key + top=5, ) - For more details, see: - * `Azure AI Search Overview `_ - * `Vector Search `_ - - Args: - name (str): Name for the tool instance, used to 
identify it in the agent's toolkit. - description (Optional[str]): Human-readable description of what this tool does and how to use it. - endpoint (str): The full URL of your Azure AI Search service, in the format - 'https://.search.windows.net'. - index_name (str): Name of the target search index in your Azure AI Search service. - The index must be pre-created and properly configured. - api_version (str): Azure AI Search REST API version to use. Defaults to '2023-11-01'. - Only change if you need specific features from a different API version. - credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential: - - AzureKeyCredential: For API key authentication (admin/query key) - - TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential) - query_type (Literal["keyword", "fulltext", "vector", "semantic"]): The search query mode to use: - - 'keyword': Basic keyword search (default) - - 'fulltext': Full Lucene query syntax - - 'vector': Vector similarity search - - 'semantic': Semantic search using semantic configuration - search_fields (Optional[List[str]]): List of index fields to search within. If not specified, - searches all searchable fields. Example: ['title', 'content']. - select_fields (Optional[List[str]]): Fields to return in search results. If not specified, - returns all fields. Use to optimize response size. - vector_fields (Optional[List[str]]): Vector field names for vector search. Must be configured - in your search index as vector fields. Required for vector search. - top (Optional[int]): Maximum number of documents to return in search results. - Helps control response size and processing time. - retry_enabled (bool): Whether to enable retry policy for transient errors. Defaults to True. - retry_max_attempts (Optional[int]): Maximum number of retry attempts for failed requests. Defaults to 3. - retry_mode (Literal["fixed", "exponential"]): Retry backoff strategy: fixed or exponential. 
Defaults to "exponential". - enable_caching (bool): Whether to enable client-side caching of search results. Defaults to False. - cache_ttl_seconds (int): Time-to-live for cached search results in seconds. Defaults to 300 (5 minutes). - filter (Optional[str]): OData filter expression to refine search results. + # Configuration for hybrid search with semantic ranking + hybrid_config = AzureAISearchConfig( + name="hybrid-search", + endpoint="https://your-search.search.windows.net", + index_name="", + credential=AzureKeyCredential(""), + query_type="semantic", + semantic_config_name="", # Name of your semantic configuration + search_fields=["content", "title"], # Update with your search fields + vector_fields=["embedding"], # Update with your vector field name + embedding_provider="openai", + embedding_model="text-embedding-ada-002", + openai_api_key="", # Your OpenAI API key + top=5, + ) """ - name: str = Field(description="The name of the tool") - description: Optional[str] = Field(default=None, description="A description of the tool") - endpoint: str = Field(description="The endpoint URL for your Azure AI Search service") - index_name: str = Field(description="The name of the search index to query") - api_version: str = Field(default="2023-11-01", description="API version to use") - credential: Union[AzureKeyCredential, TokenCredential] = Field( - description="The credential to use for authentication" + name: str = Field(description="The name of this tool instance") + description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose") + endpoint: str = Field(description="The full URL of your Azure AI Search service") + index_name: str = Field(description="Name of the search index to query") + credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field( + description="Azure credential for authentication (API key or token)" ) - query_type: Literal["keyword", "fulltext", "vector", "semantic"] = Field( - default="keyword", 
- description="Type of query to perform (keyword for classic, fulltext for Lucene, vector for embedding, semantic for semantic/AI search)", + api_version: str = Field( + default=DEFAULT_API_VERSION, + description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.", ) - search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in") - select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results") - vector_fields: Optional[List[str]] = Field( - default=None, description="Optional list of vector fields for vector search" + query_type: QueryTypeLiteral = Field( + default="simple", description="Type of search to perform: simple, full, semantic, or vector" ) - top: Optional[int] = Field(default=None, description="Optional number of results to return") - filter: Optional[str] = Field(default=None, description="Optional OData filter expression to refine search results") - - retry_enabled: bool = Field(default=True, description="Whether to enable retry policy for transient errors") - retry_max_attempts: Optional[int] = Field( - default=3, description="Maximum number of retry attempts for failed requests" + search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents") + select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results") + vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search") + top: Optional[int] = Field( + default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN." 
) - retry_mode: Literal["fixed", "exponential"] = Field( - default="exponential", - description="Retry backoff strategy: fixed or exponential", + filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results") + semantic_config_name: Optional[str] = Field( + default=None, description="Semantic configuration name for enhanced results" ) - enable_caching: bool = Field( - default=False, - description="Whether to enable client-side caching of search results", - ) - cache_ttl_seconds: int = Field( - default=300, # 5 minutes - description="Time-to-live for cached search results in seconds", - ) + enable_caching: bool = Field(default=False, description="Whether to cache search results") + cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds") embedding_provider: Optional[str] = Field( - default=None, - description="Name of embedding provider to use (e.g., 'azure_openai', 'openai')", - ) - embedding_model: Optional[str] = Field(default=None, description="Model name to use for generating embeddings") - embedding_dimension: Optional[int] = Field( - default=None, description="Dimension of embedding vectors produced by the model" + default=None, description="Name of embedding provider for client-side embeddings" ) + embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings") + openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings") + openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings") + openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings") model_config = {"arbitrary_types_allowed": True} - @classmethod - @model_validator(mode="before") - def validate_credentials(cls: Type[T], data: Any) -> Any: - """Validate and convert credential data.""" - if not isinstance(data, dict): - return 
data - - result = {} - - for key, value in data.items(): - result[str(key)] = value - - if "credential" in result: - credential = result["credential"] - - if isinstance(credential, dict) and "api_key" in credential: - api_key = str(credential["api_key"]) - result["credential"] = AzureKeyCredential(api_key) - - return result - - def model_dump(self, **kwargs: Any) -> Dict[str, Any]: - """Custom model_dump to handle credentials.""" - result: Dict[str, Any] = super().model_dump(**kwargs) - - if isinstance(self.credential, AzureKeyCredential): - result["credential"] = {"type": "AzureKeyCredential"} - elif isinstance(self.credential, TokenCredential): - result["credential"] = {"type": "TokenCredential"} - - return result + @field_validator("endpoint") + def validate_endpoint(cls, v: str) -> str: + """Validate that the endpoint is a valid URL.""" + if not v.startswith(("http://", "https://")): + raise ValueError("endpoint must be a valid URL starting with http:// or https://") + return v + + @field_validator("query_type") + def normalize_query_type(cls, v: QueryTypeLiteral) -> QueryTypeLiteral: + """Normalize query type to standard values.""" + if not v: + return "simple" + + if isinstance(v, str) and v.lower() == "fulltext": + return "full" + + return v + + @field_validator("top") + def validate_top(cls, v: Optional[int]) -> Optional[int]: + """Ensure top is a positive integer if provided.""" + if v is not None and v <= 0: + raise ValueError("top must be a positive integer") + return v + + @model_validator(mode="after") + def validate_interdependent_fields(self) -> "AzureAISearchConfig": + """Validate interdependent fields after all fields have been parsed.""" + if self.query_type == "semantic" and not self.semantic_config_name: + raise ValueError("semantic_config_name must be provided when query_type is 'semantic'") + + if self.query_type == "vector" and not self.vector_fields: + raise ValueError("vector_fields must be provided for vector search") + + if ( + 
self.embedding_provider + and self.embedding_provider.lower() == "azure_openai" + and self.embedding_model + and not self.openai_endpoint + ): + raise ValueError("openai_endpoint must be provided for azure_openai embedding provider") + + return self diff --git a/python/packages/autogen-ext/tests/tools/azure/conftest.py b/python/packages/autogen-ext/tests/tools/azure/conftest.py index 4b4a974ff0cb..6d3a8569f127 100644 --- a/python/packages/autogen-ext/tests/tools/azure/conftest.py +++ b/python/packages/autogen-ext/tests/tools/azure/conftest.py @@ -1,7 +1,7 @@ """Test fixtures for Azure AI Search tool tests.""" import warnings -from typing import Any, Dict, Generator, List, Protocol, Type, TypeVar, Union +from typing import Any, Dict, Iterator, List, Protocol, TypeVar, Union from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -9,6 +9,17 @@ T = TypeVar("T") +try: + from azure.core.credentials import AzureKeyCredential, TokenCredential + + azure_sdk_available = True +except ImportError: + azure_sdk_available = False + +skip_if_no_azure_sdk = pytest.mark.skipif( + not azure_sdk_available, reason="Azure SDK components (azure-search-documents, azure-identity) not available" +) + class AccessTokenProtocol(Protocol): """Protocol matching Azure AccessToken.""" @@ -47,18 +58,13 @@ def get_token( return MockAccessToken("mock-token", 12345) -try: - from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential - - _access_token_type: Type[AccessToken] = AccessToken - azure_sdk_available = True -except ImportError: - AzureKeyCredential = MockAzureKeyCredential # type: ignore - TokenCredential = MockTokenCredential # type: ignore - _access_token_type = MockAccessToken # type: ignore - azure_sdk_available = False - -CredentialType = Union[AzureKeyCredential, TokenCredential, MockAzureKeyCredential, MockTokenCredential, Any] +CredentialType = Union[ + AzureKeyCredential, # pyright: ignore [reportPossiblyUnboundVariable] + TokenCredential, # 
pyright: ignore [reportPossiblyUnboundVariable] + MockAzureKeyCredential, + MockTokenCredential, + Any, +] needs_azure_sdk = pytest.mark.skipif(not azure_sdk_available, reason="Azure SDK not available") @@ -70,10 +76,14 @@ def get_token( @pytest.fixture -def mock_vectorized_query() -> Generator[MagicMock, None, None]: +def mock_vectorized_query() -> MagicMock: """Create a mock VectorizedQuery for testing.""" - with patch("azure.search.documents.models.VectorizedQuery") as mock: - yield mock + if azure_sdk_available: + from azure.search.documents.models import VectorizedQuery + + return MagicMock(spec=VectorizedQuery) + else: + return MagicMock() @pytest.fixture @@ -87,7 +97,7 @@ def test_config() -> ComponentModel: "endpoint": "https://test-search-service.search.windows.net", "index_name": "test-index", "api_version": "2023-10-01-Preview", - "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, + "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable] "query_type": "keyword", "search_fields": ["content", "title"], "select_fields": ["id", "content", "title", "source"], @@ -106,7 +116,7 @@ def keyword_config() -> ComponentModel: "description": "Keyword search tool", "endpoint": "https://test-search-service.search.windows.net", "index_name": "test-index", - "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, + "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable] "query_type": "keyword", "search_fields": ["content", "title"], "select_fields": ["id", "content", "title", "source"], @@ -125,7 +135,7 @@ def vector_config() -> ComponentModel: "endpoint": "https://test-search-service.search.windows.net", "index_name": "test-index", "api_version": "2023-10-01-Preview", - "credential": 
AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, + "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable] "query_type": "vector", "vector_fields": ["embedding"], "select_fields": ["id", "content", "title", "source"], @@ -145,7 +155,7 @@ def hybrid_config() -> ComponentModel: "endpoint": "https://test-search-service.search.windows.net", "index_name": "test-index", "api_version": "2023-10-01-Preview", - "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, + "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable] "query_type": "keyword", "search_fields": ["content", "title"], "vector_fields": ["embedding"], @@ -196,108 +206,18 @@ async def get_count(self) -> int: @pytest.fixture -def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> tuple[MagicMock, Any]: - """Create a mock search client for testing.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - - search_results = AsyncIterator(mock_search_response) - mock_client.search = MagicMock(return_value=search_results) - - patcher = patch("azure.search.documents.aio.SearchClient", return_value=mock_client) - - return mock_client, patcher - - -def test_validate_credentials_scenarios() -> None: - """Test all validate_credentials scenarios to ensure full code coverage.""" - import sys - - from autogen_ext.tools.azure._config import AzureAISearchConfig - - module_path = sys.modules[AzureAISearchConfig.__module__].__file__ - if module_path is not None: - assert "autogen-ext" in module_path - - data: Any = "not a dict" - result: Any = AzureAISearchConfig.validate_credentials(data) # type: ignore - assert result == data - - data_empty: Dict[str, Any] = {} 
- result_empty: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_empty) # type: ignore - assert isinstance(result_empty, dict) - - data_items: Dict[str, Any] = {"key1": "value1", "key2": "value2"} - result_items: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_items) # type: ignore - assert result_items["key1"] == "value1" - assert result_items["key2"] == "value2" - - data_with_api_key: Dict[str, Any] = { - "name": "test", - "endpoint": "https://test.search.windows.net", - "index_name": "test-index", - "credential": {"api_key": "test-key"}, - } - result_with_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_api_key) # type: ignore - - cred = result_with_api_key["credential"] # type: ignore - assert isinstance(cred, (AzureKeyCredential, MockAzureKeyCredential)) - assert hasattr(cred, "key") - assert cred.key == "test-key" # type: ignore - - credential: Any = AzureKeyCredential("test-key") - data_with_credential: Dict[str, Any] = { - "name": "test", - "endpoint": "https://test.search.windows.net", - "index_name": "test-index", - "credential": credential, - } - result_with_credential: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_credential) # type: ignore - assert result_with_credential["credential"] is credential - - data_without_api_key: Dict[str, Any] = { - "name": "test", - "endpoint": "https://test.search.windows.net", - "index_name": "test-index", - "credential": {"username": "test-user", "password": "test-pass"}, - } - result_without_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_without_api_key) # type: ignore - assert result_without_api_key["credential"] == {"username": "test-user", "password": "test-pass"} - - -def test_model_dump_scenarios() -> None: - """Test all model_dump scenarios to ensure full code coverage.""" - import sys - - from autogen_ext.tools.azure._config import AzureAISearchConfig - - module_path = 
sys.modules[AzureAISearchConfig.__module__].__file__ - if module_path is not None: - assert "autogen-ext" in module_path - - config = AzureAISearchConfig( - name="test", - endpoint="https://endpoint", - index_name="index", - credential=AzureKeyCredential("key"), # type: ignore - ) - result = config.model_dump() - assert result["credential"] == {"type": "AzureKeyCredential"} +def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> Iterator[MagicMock]: + """Create a mock search client for testing, with the patch active.""" + mock_client_instance = MagicMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) - if azure_sdk_available: - from azure.core.credentials import AccessToken - from azure.core.credentials import TokenCredential as RealTokenCredential - - class TestTokenCredential(RealTokenCredential): - def get_token(self, *args: Any, **kwargs: Any) -> AccessToken: - """Override of get_token method that returns proper type.""" - return AccessToken("test-token", 12345) - - config = AzureAISearchConfig( - name="test", endpoint="https://endpoint", index_name="index", credential=TestTokenCredential() - ) - result = config.model_dump() - assert result["credential"] == {"type": "TokenCredential"} - else: - pytest.skip("Skipping TokenCredential test - Azure SDK not available") + search_results_iterator = AsyncIterator(mock_search_response) + mock_client_instance.search = MagicMock(return_value=search_results_iterator) + + patch_target = "autogen_ext.tools.azure._ai_search.SearchClient" + patcher = patch(patch_target, return_value=mock_client_instance) + + patcher.start() + yield mock_client_instance + patcher.stop() diff --git a/python/packages/autogen-ext/tests/tools/azure/test_ai_search_config.py b/python/packages/autogen-ext/tests/tools/azure/test_ai_search_config.py new file mode 100644 index 000000000000..ddcc26e36869 --- /dev/null +++ 
b/python/packages/autogen-ext/tests/tools/azure/test_ai_search_config.py @@ -0,0 +1,297 @@ +from typing import Any, Dict, cast + +import pytest +from autogen_ext.tools.azure._config import AzureAISearchConfig, QueryTypeLiteral +from azure.core.credentials import AzureKeyCredential +from pydantic import ValidationError + +from tests.tools.azure.conftest import azure_sdk_available + +skip_if_no_azure_sdk = pytest.mark.skipif( + not azure_sdk_available, reason="Azure SDK components (azure-search-documents, azure-identity) not available" +) + +# ===================================== +# Basic Configuration Tests +# ===================================== + + +def test_basic_config_creation() -> None: + """Test that a basic valid configuration can be created.""" + config = AzureAISearchConfig( + name="test_tool", + endpoint="https://test-search.search.windows.net", + index_name="test-index", + credential=AzureKeyCredential("test-key"), + ) + + assert config.name == "test_tool" + assert config.endpoint == "https://test-search.search.windows.net" + assert config.index_name == "test-index" + assert isinstance(config.credential, AzureKeyCredential) + assert config.query_type == "simple" # default value + + +def test_endpoint_validation() -> None: + """Test that endpoint validation works correctly.""" + valid_endpoints = ["https://test.search.windows.net", "http://localhost:8080"] + + for endpoint in valid_endpoints: + config = AzureAISearchConfig( + name="test_tool", + endpoint=endpoint, + index_name="test-index", + credential=AzureKeyCredential("test-key"), + ) + assert config.endpoint == endpoint + + invalid_endpoints = [ + "test.search.windows.net", + "ftp://test.search.windows.net", + "", + ] + + for endpoint in invalid_endpoints: + with pytest.raises(ValidationError) as exc: + AzureAISearchConfig( + name="test_tool", + endpoint=endpoint, + index_name="test-index", + credential=AzureKeyCredential("test-key"), + ) + assert "endpoint must be a valid URL" in str(exc.value) + 
+
+def test_top_validation() -> None:
+    """Test validation of top parameter."""
+    valid_tops = [1, 5, 10, 100]
+
+    for top in valid_tops:
+        config = AzureAISearchConfig(
+            name="test_tool",
+            endpoint="https://test.search.windows.net",
+            index_name="test-index",
+            credential=AzureKeyCredential("test-key"),
+            top=top,
+        )
+        assert config.top == top
+
+    invalid_tops = [0, -1, -10]
+
+    for top in invalid_tops:
+        with pytest.raises(ValidationError) as exc:
+            AzureAISearchConfig(
+                name="test_tool",
+                endpoint="https://test.search.windows.net",
+                index_name="test-index",
+                credential=AzureKeyCredential("test-key"),
+                top=top,
+            )
+        assert "top must be a positive integer" in str(exc.value)
+
+
+# =====================================
+# Query Type Tests
+# =====================================
+
+
+def test_query_type_normalization() -> None:
+    """Test that query_type normalization works correctly."""
+    standard_query_types = {
+        "simple": "simple",
+        "full": "full",
+        "semantic": "semantic",
+        "vector": "vector",
+    }
+
+    for input_type, expected_type in standard_query_types.items():
+        config_args: Dict[str, Any] = {
+            "name": "test_tool",
+            "endpoint": "https://test.search.windows.net",
+            "index_name": "test-index",
+            "credential": AzureKeyCredential("test-key"),
+            "query_type": cast(QueryTypeLiteral, input_type),
+        }
+
+        if input_type == "semantic":
+            config_args["semantic_config_name"] = "my-semantic-config"
+        elif input_type == "vector":
+            config_args["vector_fields"] = ["content_vector"]
+
+        config = AzureAISearchConfig(**config_args)
+        assert config.query_type == expected_type
+
+    with pytest.raises(ValidationError) as exc:
+        AzureAISearchConfig(
+            name="test_tool",
+            endpoint="https://test.search.windows.net",
+            index_name="test-index",
+            credential=AzureKeyCredential("test-key"),
+            query_type=cast(Any, "invalid_type"),
+        )
+    assert "Input should be" in str(exc.value)
+
+
+def test_semantic_config_validation() -> None:
+    """Test validation of semantic configuration."""
+    config = AzureAISearchConfig(
+        name="test_tool",
+        endpoint="https://test.search.windows.net",
+        index_name="test-index",
+        credential=AzureKeyCredential("test-key"),
+        query_type=cast(QueryTypeLiteral, "semantic"),
+        semantic_config_name="my-semantic-config",
+    )
+    assert config.query_type == "semantic"
+    assert config.semantic_config_name == "my-semantic-config"
+
+    with pytest.raises(ValidationError) as exc:
+        AzureAISearchConfig(
+            name="test_tool",
+            endpoint="https://test.search.windows.net",
+            index_name="test-index",
+            credential=AzureKeyCredential("test-key"),
+            query_type=cast(QueryTypeLiteral, "semantic"),
+        )
+    assert "semantic_config_name must be provided" in str(exc.value)
+
+
+def test_vector_fields_validation() -> None:
+    """Test that vector_fields is accepted and stored for vector search (positive case only)."""
+    config = AzureAISearchConfig(
+        name="test_tool",
+        endpoint="https://test.search.windows.net",
+        index_name="test-index",
+        credential=AzureKeyCredential("test-key"),
+        query_type=cast(QueryTypeLiteral, "vector"),
+        vector_fields=["content_vector"],
+    )
+    assert config.query_type == "vector"
+    assert config.vector_fields == ["content_vector"]
+
+
+# =====================================
+# Embedding Configuration Tests
+# =====================================
+
+
+def test_azure_openai_endpoint_validation() -> None:
+    """Test validation of Azure OpenAI endpoint for client-side embeddings."""
+    config = AzureAISearchConfig(
+        name="test_tool",
+        endpoint="https://test.search.windows.net",
+        index_name="test-index",
+        credential=AzureKeyCredential("test-key"),
+        embedding_provider="azure_openai",
+        embedding_model="text-embedding-ada-002",
+        openai_endpoint="https://test.openai.azure.com",
+    )
+    assert config.embedding_provider == "azure_openai"
+    assert config.embedding_model == "text-embedding-ada-002"
+    assert config.openai_endpoint == "https://test.openai.azure.com"
+
+    with pytest.raises(ValidationError) as exc:
+        AzureAISearchConfig(
+            name="test_tool",
+            endpoint="https://test.search.windows.net",
+            index_name="test-index",
+            credential=AzureKeyCredential("test-key"),
+            embedding_provider="azure_openai",
+            embedding_model="text-embedding-ada-002",
+        )
+    assert "openai_endpoint must be provided for azure_openai" in str(exc.value)
+
+    config = AzureAISearchConfig(
+        name="test_tool",
+        endpoint="https://test.search.windows.net",
+        index_name="test-index",
+        credential=AzureKeyCredential("test-key"),
+        embedding_provider="openai",
+        embedding_model="text-embedding-ada-002",
+    )
+    assert config.embedding_provider == "openai"
+    assert config.embedding_model == "text-embedding-ada-002"
+    assert config.openai_endpoint is None
+
+
+# =====================================
+# Credential and Serialization Tests
+# =====================================
+
+
+def _make_token_credential() -> Any:
+    """Build a minimal AsyncTokenCredential stub shared by the token-credential tests.
+
+    The azure imports are function-local so they run only in tests guarded by skip_if_no_azure_sdk.
+    """
+    from azure.core.credentials import AccessToken
+    from azure.core.credentials_async import AsyncTokenCredential
+
+    class StubTokenCredential(AsyncTokenCredential):
+        async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
+            return AccessToken("test-token", 12345)
+
+        async def close(self) -> None:
+            pass
+
+        async def __aenter__(self) -> "StubTokenCredential":
+            return self
+
+        async def __aexit__(self, *args: Any) -> None:
+            await self.close()
+
+    return StubTokenCredential()
+
+
+def test_credential_validation() -> None:
+    """An AzureKeyCredential is exposed on the config with its key intact."""
+    config = AzureAISearchConfig(
+        name="test_tool",
+        endpoint="https://test.search.windows.net",
+        index_name="test-index",
+        credential=AzureKeyCredential("test-key"),
+    )
+    assert isinstance(config.credential, AzureKeyCredential)
+    assert config.credential.key == "test-key"
+
+
+@skip_if_no_azure_sdk
+def test_token_credential_validation() -> None:
+    """An AsyncTokenCredential implementation is accepted as the credential."""
+    from azure.core.credentials_async import AsyncTokenCredential
+
+    config = AzureAISearchConfig(
+        name="test",
+        endpoint="https://endpoint",
+        index_name="index",
+        credential=_make_token_credential(),
+    )
+    assert isinstance(config.credential, AsyncTokenCredential)
+
+
+def test_model_dump_scenarios() -> None:
+    """model_dump() hands back the AzureKeyCredential object with its key intact."""
+    config = AzureAISearchConfig(
+        name="test",
+        endpoint="https://endpoint",
+        index_name="index",
+        credential=AzureKeyCredential("key"),
+    )
+    result = config.model_dump()
+    assert isinstance(result["credential"], AzureKeyCredential)
+    assert result["credential"].key == "key"
+
+
+@skip_if_no_azure_sdk
+def test_model_dump_token_credential() -> None:
+    """model_dump() keeps an AsyncTokenCredential as a credential object."""
+    from azure.core.credentials_async import AsyncTokenCredential
+
+    config = AzureAISearchConfig(
+        name="test",
+        endpoint="https://endpoint",
+        index_name="index",
+        credential=_make_token_credential(),
+    )
+    result = config.model_dump()
+    assert isinstance(result["credential"], AsyncTokenCredential)
diff --git a/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py b/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py index 13b70e823d62..67b1d357fd50 100644 --- a/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py +++ b/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py @@ -1,1374 +1,971 @@ -"""Tests for the Azure AI Search tool.""" +"""Tests for Azure AI Search tool.""" -# pyright: reportPrivateUsage=false - -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast +import asyncio +from collections.abc import Generator +from typing import Any, Dict, List from unittest.mock import AsyncMock, MagicMock, patch import pytest from autogen_core import CancellationToken -from autogen_ext.tools.azure._ai_search import ( +from autogen_ext.tools.azure import ( + AzureAISearchConfig,
AzureAISearchTool, - BaseAzureAISearchTool, - SearchQuery, SearchResult, SearchResults, - _allow_private_constructor, ) -from azure.core.credentials import AzureKeyCredential, TokenCredential -from azure.core.exceptions import HttpResponseError +from autogen_ext.tools.azure._ai_search import BaseAzureAISearchTool +from autogen_ext.tools.azure._config import DEFAULT_API_VERSION +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError +from pydantic import BaseModel, Field, ValidationError + +MOCK_ENDPOINT = "https://test-search.search.windows.net" +MOCK_INDEX = "test-index" +MOCK_API_KEY = "test-key" +MOCK_CREDENTIAL = AzureKeyCredential(MOCK_API_KEY) -class MockAsyncIterator: - """Mock for async iterator to use in tests.""" +class MockAsyncTokenCredential(AsyncTokenCredential): + """Mock async token credential for testing.""" - def __init__(self, items: List[Dict[str, Any]]) -> None: - self.items = items.copy() + async def get_token(self, *scopes: str, **kwargs: Any) -> Any: + return "mock-token" - def __aiter__(self) -> "MockAsyncIterator": - return self + async def close(self) -> None: + pass - async def __anext__(self) -> Dict[str, Any]: - if not self.items: - raise StopAsyncIteration - return self.items.pop(0) + async def __aexit__(self, exc_type: Any = None, exc_val: Any = None, exc_tb: Any = None) -> None: + await self.close() @pytest.fixture -async def search_tool() -> AsyncGenerator[AzureAISearchTool, None]: - """Create a concrete search tool for testing.""" +def search_config() -> AzureAISearchConfig: + """Fixture for basic search configuration.""" + return AzureAISearchConfig( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + description="Test search tool", + ) - class ConcreteSearchTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: 
- return [0.1, 0.2, 0.3] - token = _allow_private_constructor.set(True) - try: - tool = ConcreteSearchTool( - name="test-search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=cast(TokenCredential, AzureKeyCredential("test-key")), - query_type="keyword", - search_fields=["title", "content"], - select_fields=["title", "content"], - top=10, - ) - yield tool - finally: - _allow_private_constructor.reset(token) +@pytest.fixture +def mock_search_client() -> Generator[AsyncMock, None, None]: + """Fixture for mocked search client.""" + with patch("azure.search.documents.aio.SearchClient", autospec=True) as mock: + mock_client = AsyncMock() + mock.return_value = mock_client + yield mock_client -@pytest.mark.asyncio -async def test_search_tool_run(search_tool: AsyncGenerator[AzureAISearchTool, None]) -> None: - """Test the run method of the search tool.""" - tool = await anext(search_tool) - query = "test query" - cancellation_token = CancellationToken() - - with patch.object(tool, "_get_client", AsyncMock()) as mock_client: - mock_client.return_value.search = AsyncMock( - return_value=MockAsyncIterator([{"@search.score": 0.95, "title": "Test Doc", "content": "Test Content"}]) - ) +@pytest.fixture +def mock_search_results() -> List[Dict[str, Any]]: + """Fixture for mock search results.""" + return [ + { + "id": "1", + "content": "Test content", + "@search.score": 0.8, + } + ] - results = await tool.run(query, cancellation_token) - assert isinstance(results, SearchResults) - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Test Doc" - assert results.results[0].score == 0.95 +class TestSearchQuery(BaseModel): + """Test model for query validation.""" -@pytest.mark.asyncio -async def test_search_tool_error_handling(search_tool: AsyncGenerator[AzureAISearchTool, None]) -> None: - """Test error handling in the search tool.""" - tool = await anext(search_tool) - with patch.object(tool, "_get_client", 
AsyncMock()) as mock_client: - mock_client.return_value.search = AsyncMock(side_effect=ValueError("Test error")) - - with pytest.raises(ValueError, match="Test error"): - await tool.run("test query", CancellationToken()) + query: str = Field(min_length=1) @pytest.mark.asyncio -async def test_search_tool_cancellation(search_tool: AsyncGenerator[AzureAISearchTool, None]) -> None: - """Test cancellation of the search tool.""" - tool = await anext(search_tool) - cancellation_token = CancellationToken() - cancellation_token.cancel() +async def test_search_query_model() -> None: + """Test SearchQuery model validation.""" + query = TestSearchQuery(query="test query") + assert query.query == "test query" - with pytest.raises(ValueError, match="cancelled"): - await tool.run("test query", cancellation_token) + with pytest.raises(ValidationError): + TestSearchQuery(query="") @pytest.mark.asyncio -async def test_search_tool_vector_search() -> None: - """Test vector search functionality.""" - - class ConcreteSearchTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return [0.1, 0.2, 0.3] - - token = _allow_private_constructor.set(True) - try: - tool = ConcreteSearchTool( - name="vector-search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=cast(TokenCredential, AzureKeyCredential("test-key")), - query_type="vector", - vector_fields=["embedding"], - select_fields=["title", "content"], - top=10, - ) - - with patch.object(tool, "_get_client", AsyncMock()) as mock_client: - mock_client.return_value.search = AsyncMock( - return_value=MockAsyncIterator( - [{"@search.score": 0.95, "title": "Vector Doc", "content": "Vector Content"}] - ) - ) - - results = await tool.run("vector query", CancellationToken()) - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Vector Doc" - assert results.results[0].score == 0.95 - finally: - _allow_private_constructor.reset(token) - - -class 
ConcreteAzureAISearchTool(AzureAISearchTool): - """Concrete implementation for testing.""" - - async def _get_embedding(self, query: str) -> List[float]: - return [0.1, 0.2, 0.3] +async def test_search_result_model() -> None: + """Test SearchResult model.""" + result = SearchResult(score=0.8, content={"title": "Test", "text": "Content"}, metadata={"@search.score": 0.8}) + assert result.score == 0.8 + assert result.content["title"] == "Test" + assert result.metadata["@search.score"] == 0.8 @pytest.mark.asyncio -async def test_create_keyword_search() -> None: - """Test the create_keyword_search factory method.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="keyword_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=cast(TokenCredential, AzureKeyCredential("test-key")), - search_fields=["title", "content"], - select_fields=["title", "content"], - filter="category eq 'test'", - top=5, +async def test_search_results_model() -> None: + """Test SearchResults model.""" + results = SearchResults( + results=[ + SearchResult(score=0.8, content={"title": "Test1"}, metadata={"@search.score": 0.8}), + SearchResult(score=0.6, content={"title": "Test2"}, metadata={"@search.score": 0.6}), + ] ) - - assert tool.name == "keyword_search" - assert tool.search_config.query_type == "keyword" - assert tool.search_config.filter == "category eq 'test'" + assert len(results.results) == 2 + assert results.results[0].score == 0.8 + assert results.results[1].content["title"] == "Test2" @pytest.mark.asyncio async def test_create_full_text_search() -> None: - """Test the create_full_text_search factory method.""" - tool = ConcreteAzureAISearchTool.create_full_text_search( - name="full_text_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=cast(TokenCredential, AzureKeyCredential("test-key")), - search_fields=["title", "content"], - select_fields=["title", "content"], - filter="category eq 
'test'", - top=5, - ) + """Test creation of full text search tool.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + search_fields=["content"], + query_type="simple", + ) + assert tool.name == "test-search" + assert tool.search_config.query_type == "simple" + assert tool.search_config.search_fields == ["content"] + + with pytest.raises(ValueError, match="semantic_config_name is required"): + AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + query_type="semantic", + ) - assert tool.name == "full_text_search" - assert tool.search_config.query_type == "fulltext" - assert tool.search_config.search_fields == ["title", "content"] - assert tool.search_config.select_fields == ["title", "content"] - assert tool.search_config.filter == "category eq 'test'" - assert tool.search_config.top == 5 + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + query_type="semantic", + semantic_config_name="default", + ) + assert tool.search_config.query_type == "semantic" + assert tool.search_config.semantic_config_name == "default" @pytest.mark.asyncio async def test_create_vector_search() -> None: - """Test the create_vector_search factory method.""" - tool = ConcreteAzureAISearchTool.create_vector_search( - name="vector_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + """Test creation of vector search tool.""" + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, vector_fields=["embedding"], - select_fields=["title", "content"], - top=5, ) - - assert tool.name == "vector_search" assert 
tool.search_config.query_type == "vector" assert tool.search_config.vector_fields == ["embedding"] + with pytest.raises(ValueError, match="openai_endpoint is required"): + AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + ) + @pytest.mark.asyncio async def test_create_hybrid_search() -> None: - """Test the create_hybrid_search factory method (hybrid = text + vector, query_type will be 'fulltext' or 'semantic').""" - tool = ConcreteAzureAISearchTool.create_hybrid_search( - name="hybrid_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + """Test creation of hybrid search tool.""" + tool = AzureAISearchTool.create_hybrid_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, vector_fields=["embedding"], - search_fields=["title", "content"], - select_fields=["title", "content"], - top=5, + search_fields=["content"], ) - - assert tool.name == "hybrid_search" - assert tool.search_config.query_type in ("fulltext", "semantic") assert tool.search_config.vector_fields == ["embedding"] - assert tool.search_config.search_fields == ["title", "content"] - + assert tool.search_config.search_fields == ["content"] -@pytest.mark.asyncio -async def test_run_invalid_query() -> None: - """Test the run method with an invalid query format.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + tool = AzureAISearchTool.create_hybrid_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + search_fields=["content"], + 
query_type="semantic", + semantic_config_name="default", ) - - invalid_query: Dict[str, Any] = {"invalid_key": "invalid_value"} - with pytest.raises(ValueError, match="Invalid search query format"): - await tool.run(invalid_query) + assert tool.search_config.query_type == "semantic" + assert tool.search_config.semantic_config_name == "default" @pytest.mark.asyncio -async def test_process_credential_dict() -> None: - """Test the _process_credential method with a dictionary credential.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential={"api_key": "test-key"}, - ) +async def test_search_execution(mock_search_client: AsyncMock, mock_search_results: List[Dict[str, Any]]) -> None: + """Test search execution with mocked client.""" + mock_search_client.search.return_value.__aiter__.return_value = mock_search_results - assert isinstance(tool.search_config.credential, AzureKeyCredential) - assert tool.search_config.credential.key == "test-key" - - -@pytest.mark.asyncio -async def test_run_empty_query() -> None: - """Test the run method with an empty query.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - with patch.object(tool, "_get_client", AsyncMock()): - with pytest.raises(ValueError, match="Invalid search query format"): - await tool.run("") - - -@pytest.mark.asyncio -async def test_get_client_initialization() -> None: - """Test the _get_client method for proper initialization.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - 
credential=AzureKeyCredential("test-key"), - ) - - assert tool.search_config.endpoint == "https://test.search.windows.net" - assert tool.search_config.index_name == "test-index" - - mock_client = AsyncMock() - - class MockAsyncIterator: - def __init__(self, items: List[Dict[str, Any]]) -> None: - self.items = items - - def __aiter__(self) -> "MockAsyncIterator": - return self - - async def __anext__(self) -> Dict[str, Any]: - if not self.items: - raise StopAsyncIteration - return self.items.pop(0) - - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.9, "title": "Test Result"}]) - - with patch.object(tool, "_get_client", return_value=mock_client): - results = await tool.run("test query", CancellationToken()) - mock_client.search.assert_called_once() + with patch.object(tool, "_get_client", return_value=mock_search_client): + results = await tool.run("test query") assert len(results.results) == 1 - assert results.results[0].content["title"] == "Test Result" - assert results.results[0].score == 0.9 - - -@pytest.mark.asyncio -async def test_return_value_as_string() -> None: - """Test the return_value_as_string method for formatting search results.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) - - results = SearchResults( - results=[ - SearchResult(score=0.95, content={"title": "Doc 1"}, metadata={}), - SearchResult(score=0.85, content={"title": "Doc 2"}, metadata={}), - ] - ) - - result_string = tool.return_value_as_string(results) - assert "Result 1 (Score: 0.95): title: Doc 1" in result_string - assert "Result 2 (Score: 0.85): title: Doc 2" in result_string - - -@pytest.mark.asyncio -async def test_return_value_as_string_empty() -> None: - """Test the return_value_as_string method with empty results.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - 
endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) + assert results.results[0].score == 0.8 + assert results.results[0].content["content"] == "Test content" - results = SearchResults(results=[]) - result_string = tool.return_value_as_string(results) - assert result_string == "No results found." + mock_search_client.search.assert_called_once() + call_kwargs = mock_search_client.search.call_args[1] + assert call_kwargs["search_text"] == "test query" @pytest.mark.asyncio -async def test_load_component() -> None: - """Test the load_component method for proper deserialization.""" - model = { - "provider": "autogen_ext.tools.azure.BaseAzureAISearchTool", - "config": { - "name": "test_tool", - "endpoint": "https://test.search.windows.net", - "index_name": "test-index", - "credential": {"api_key": "test-key"}, - "query_type": "keyword", - "search_fields": ["title", "content"], - "select_fields": ["title", "content"], - "top": 5, - }, - } - - tool = ConcreteAzureAISearchTool.load_component(model) - assert tool.name == "test_tool" - assert tool.search_config.query_type == "keyword" - assert tool.search_config.search_fields == ["title", "content"] - +async def test_search_with_caching(mock_search_client: AsyncMock, mock_search_results: List[Dict[str, Any]]) -> None: + """Test search caching functionality.""" + mock_search_client.search.return_value.__aiter__.return_value = mock_search_results -@pytest.mark.asyncio -async def test_caching_functionality() -> None: - """Test the caching functionality of the search tool.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="cache_test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, enable_caching=True, 
cache_ttl_seconds=300, ) - mock_client = AsyncMock() - mock_client.__aenter__.return_value = mock_client - mock_client.__aexit__.return_value = None - - test_result = {"@search.score": 0.9, "title": "Test Result"} + with patch.object(tool, "_get_client", return_value=mock_search_client): + await tool.run("test query") + assert mock_search_client.search.call_count == 1 - class MockAsyncIterator: - def __init__(self) -> None: - self.returned = False - - def __aiter__(self) -> "MockAsyncIterator": - return self - - async def __anext__(self) -> Dict[str, Any]: - if self.returned: - raise StopAsyncIteration - self.returned = True - return test_result - - mock_client.search = AsyncMock(return_value=MockAsyncIterator()) - - with patch.object(tool, "_get_client", return_value=mock_client): - results1 = await tool.run("test query") - assert len(results1.results) == 1 - assert results1.results[0].content["title"] == "Test Result" - assert mock_client.search.call_count == 1 - - mock_client.search = AsyncMock(return_value=MockAsyncIterator()) - - results2 = await tool.run("test query") - assert len(results2.results) == 1 - assert results2.results[0].content["title"] == "Test Result" - assert mock_client.search.call_count == 1 + await tool.run("test query") + assert mock_search_client.search.call_count == 1 @pytest.mark.asyncio -async def test_semantic_configuration_name_handling() -> None: - """Test handling of semantic configuration names in fulltext search.""" - tool = ConcreteAzureAISearchTool.create_full_text_search( - name="semantic_config_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - search_fields=["title", "content"], - select_fields=["title", "content"], +async def test_error_handling(mock_search_client: AsyncMock) -> None: + """Test error handling in search execution.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + 
index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.9, "title": "Semantic Test Result"}]) - - assert tool.search_config.query_type == "fulltext" - assert tool.search_config.search_fields == ["title", "content"] - - with patch.object(tool, "_get_client", return_value=mock_client): - mock_run = AsyncMock() - mock_run.return_value = SearchResults( - results=[SearchResult(score=0.9, content={"title": "Semantic Test Result"}, metadata={})] - ) - - with patch.object(tool, "run", mock_run): - results = await tool.run("semantic query") - mock_run.assert_called_once() - - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Semantic Test Result" - - -@pytest.mark.asyncio -async def test_http_response_error_handling() -> None: - """Test handling of different HTTP response errors.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) - - mock_client = AsyncMock() - http_error = HttpResponseError() - http_error.message = "401 Unauthorized: Access is denied due to invalid credentials" + with patch.object(tool, "_get_client", return_value=mock_search_client): + mock_search_client.search.side_effect = ResourceNotFoundError("Index not found") + with pytest.raises(ValueError, match="Index.*not found"): + await tool.run("test query") - with patch.object(tool, "_get_client", return_value=mock_client): - mock_client.search = AsyncMock(side_effect=http_error) + mock_search_client.search.side_effect = HttpResponseError(status_code=401, message="Unauthorized") with pytest.raises(ValueError, match="Authentication failed"): await tool.run("test query") - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - 
credential=AzureKeyCredential("invalid-key"), - ) - - with patch.object(tool, "_get_client", AsyncMock(side_effect=ValueError("Invalid key"))): - with pytest.raises(ValueError, match="Authentication failed"): + mock_search_client.search.side_effect = HttpResponseError(status_code=500, message="Internal server error") + with pytest.raises(ValueError, match="Error from Azure AI Search"): await tool.run("test query") @pytest.mark.asyncio -async def test_run_with_search_query_object() -> None: - """Test running the search with a SearchQuery object instead of a string.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) +async def test_embedding_provider_mixin() -> None: + """Test the embedding provider functionality.""" + with patch("openai.AsyncAzureOpenAI") as mock_azure_openai: + mock_client = AsyncMock() + mock_azure_openai.return_value = mock_client + mock_client.embeddings.create.return_value.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] - mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.85, "title": "Query Object Test"}]) + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_endpoint="https://test.openai.azure.com", + openai_api_key="test-key", + ) - with patch.object(tool, "_get_client", return_value=mock_client): - search_query = SearchQuery(query="advanced query") - results = await tool.run(search_query) + embedding = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] + assert len(embedding) == 3 + assert embedding == [0.1, 0.2, 0.3] - assert len(results.results) == 1 - assert results.results[0].content["title"] == 
"Query Object Test" - mock_client.search.assert_called_once() + mock_client.embeddings.create.assert_called_once_with(model="text-embedding-ada-002", input="test query") @pytest.mark.asyncio -async def test_dict_document_processing() -> None: - """Test processing of document with dict-like interface.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def test_credential_processing() -> None: + """Test credential processing logic.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) + assert isinstance(tool.search_config.credential, AzureKeyCredential) + assert tool.search_config.credential.key == MOCK_API_KEY - class DictLikeDoc: - def __init__(self, data: Dict[str, Any]) -> None: - self._data = data - - def items(self) -> List[tuple[str, Any]]: - return list(self._data.items()) - - mock_client = AsyncMock() - - class SpecialMockAsyncIterator: - def __init__(self) -> None: - self.returned = False - - def __aiter__(self) -> "SpecialMockAsyncIterator": - return self - - async def __anext__(self) -> DictLikeDoc: - if self.returned: - raise StopAsyncIteration - self.returned = True - return DictLikeDoc({"@search.score": 0.75, "title": "Dict Like Doc"}) - - mock_client.search.return_value = SpecialMockAsyncIterator() - - with patch.object(tool, "_get_client", return_value=mock_client): - results = await tool.run("test query") - - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Dict Like Doc" - assert results.results[0].score == 0.75 - - -@pytest.mark.asyncio -async def test_document_processing_error_handling() -> None: - """Test error handling during document processing.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - 
endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + mock_async_credential = MockAsyncTokenCredential() + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=mock_async_credential, ) + assert isinstance(tool.search_config.credential, AsyncTokenCredential) - mock_client = AsyncMock() - - class ProblemDoc: - def items(self) -> None: - raise AttributeError("Simulated error in document processing") - - class MixedResultsAsyncIterator: - def __init__(self) -> None: - self.docs: List[Union[Dict[str, Any], ProblemDoc]] = [ - {"@search.score": 0.9, "title": "Good Doc"}, - ProblemDoc(), - {"@search.score": 0.8, "title": "Another Good Doc"}, - ] - self.index = 0 - - def __aiter__(self) -> "MixedResultsAsyncIterator": - return self - - async def __anext__(self) -> Union[Dict[str, Any], ProblemDoc]: - if self.index >= len(self.docs): - raise StopAsyncIteration - doc = self.docs[self.index] - self.index += 1 - return doc - - mock_client.search.return_value = MixedResultsAsyncIterator() - - with patch.object(tool, "_get_client", return_value=mock_client): - results = await tool.run("test query") - - assert len(results.results) == 2 - assert results.results[0].content["title"] == "Good Doc" - assert results.results[1].content["title"] == "Another Good Doc" + with pytest.raises(ValueError, match="Invalid configuration"): + AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential={"api_key": "test-key"}, # type: ignore + ) @pytest.mark.asyncio -async def test_index_not_found_error() -> None: - """Test handling of 'index not found' error.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="nonexistent-index", - credential=AzureKeyCredential("test-key"), +async def 
test_return_value_as_string() -> None: + """Test the string representation of search results.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - not_found_error = ValueError("The index 'nonexistent-index' was not found") - - with patch.object(tool, "_get_client", AsyncMock(side_effect=not_found_error)): - with pytest.raises(ValueError, match="Index 'nonexistent-index' not found"): - await tool.run("test query") - - -@pytest.mark.asyncio -async def test_http_response_with_500_error() -> None: - """Test handling of HTTP 500 error responses.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + results = SearchResults( + results=[ + SearchResult(score=0.8, content={"title": "Test1", "text": "Content1"}, metadata={"@search.score": 0.8}), + SearchResult(score=0.6, content={"title": "Test2", "text": "Content2"}, metadata={"@search.score": 0.6}), + ] ) + result_str = tool.return_value_as_string(results) + assert "Result 1 (Score: 0.80)" in result_str + assert "Result 2 (Score: 0.60)" in result_str + assert "Test1" in result_str + assert "Content2" in result_str - http_error = HttpResponseError() - http_error.message = "500 Internal Server Error: Something went wrong on the server" - - with patch.object(tool, "_get_client", AsyncMock()) as mock_client: - mock_client.return_value.search = AsyncMock(side_effect=http_error) - - with pytest.raises(ValueError, match="Error from Azure AI Search"): - await tool.run("test query") + empty_results = SearchResults(results=[]) + assert tool.return_value_as_string(empty_results) == "No results found." 
@pytest.mark.asyncio -async def test_cancellation_during_search() -> None: - """Test cancellation token functionality during the search process.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def test_schema() -> None: + """Test tool schema generation.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - cancellation_token = CancellationToken() - cancellation_token.cancel() - - with pytest.raises(ValueError, match="Operation cancelled"): - await tool.run("test query", cancellation_token) + schema = tool.schema + assert schema["name"] == "test-search" + assert "description" in schema + assert "parameters" in schema + assert "required" in schema["parameters"] + assert schema["parameters"]["type"] == "object" + assert "query" in schema["parameters"]["properties"] + assert schema["parameters"]["required"] == ["query"] @pytest.mark.asyncio -async def test_run_with_dict_query_format() -> None: - """Test running the search with a dictionary query format with 'query' key.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) - - mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.85, "title": "Dict Query Test"}]) +async def test_vector_search_execution( + mock_search_client: AsyncMock, mock_search_results: List[Dict[str, Any]] +) -> None: + """Test vector search execution.""" + mock_search_client.search.return_value.__aiter__.return_value = mock_search_results - with patch.object(tool, "_get_client", return_value=mock_client): - query_dict = {"query": "dict style query"} - results = await tool.run(query_dict) - - 
assert len(results.results) == 1 - assert results.results[0].content["title"] == "Dict Query Test" - mock_client.search.assert_called_once() - - -@pytest.mark.asyncio -async def test_object_based_document_processing() -> None: - """Test processing of document with object attributes instead of dict interface.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="test_tool", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_endpoint="https://test.openai.azure.com", + openai_api_key="test-key", ) - class ObjectDoc: - """Test document class with object attributes.""" - - def __init__(self) -> None: - self.title = "Object Doc" - self.content = "Object content" - self._private_attr = "private" - self.__search_score = 0.8 - - mock_client = AsyncMock() - - class ObjectDocAsyncIterator: - def __init__(self) -> None: - self.returned = False - - def __aiter__(self) -> "ObjectDocAsyncIterator": - return self - - async def __anext__(self) -> ObjectDoc: - if self.returned: - raise StopAsyncIteration - self.returned = True - return ObjectDoc() - - mock_client.search.return_value = ObjectDocAsyncIterator() - - with patch.object(tool, "_get_client", return_value=mock_client): + mock_embedding = [0.1, 0.2, 0.3] + with ( + patch.object(tool, "_get_embedding", return_value=mock_embedding), + patch.object(tool, "_get_client", return_value=mock_search_client), + ): results = await tool.run("test query") - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Object Doc" - assert results.results[0].content["content"] == "Object content" - assert "_private_attr" not in results.results[0].content + 
mock_search_client.search.assert_called_once() @pytest.mark.asyncio -async def test_vector_search_with_provided_vectors() -> None: - """Test vector search using vectors provided directly in the search options.""" - tool = ConcreteAzureAISearchTool.create_vector_search( - name="vector_direct_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - vector_fields=["embedding"], - select_fields=["title", "content"], +async def test_cancellation() -> None: + """Test search cancellation.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.95, "title": "Vector Direct Test"}]) - - query = "test vector search" - - with patch.object(tool, "_get_client", return_value=mock_client): - results = await tool.run(query) - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Vector Direct Test" - - mock_client.search.assert_called_once() + token = CancellationToken() + token.cancel() + with pytest.raises(asyncio.CancelledError): + await tool.run("test query", cancellation_token=token) @pytest.mark.asyncio -async def test_credential_token_expiry_handling() -> None: - """Test handling credential token expiry and error cases.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="token_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def test_invalid_query_format() -> None: + """Test invalid query format handling.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - auth_error = HttpResponseError() - auth_error.message = "401 Unauthorized: Access token has expired or is not yet 
valid" - - with patch.object(tool, "_get_client", AsyncMock()) as mock_client: - mock_client.return_value.search = AsyncMock(side_effect=auth_error) - - with pytest.raises(ValueError, match="Authentication failed"): - await tool.run("test query") - - token_error = ValueError("401 Unauthorized: Token is invalid") - - with patch.object(tool, "_get_client", AsyncMock(side_effect=token_error)): - with pytest.raises(ValueError, match="Authentication failed"): - await tool.run("test query") + with pytest.raises(ValueError, match="Invalid search query format"): + await tool.run({"invalid": "format"}) @pytest.mark.asyncio -async def test_search_with_user_provided_vectors() -> None: - """Test the use of user-provided embedding vectors in SearchQuery.""" - tool = ConcreteAzureAISearchTool.create_vector_search( - name="vector_test_with_embeddings", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - vector_fields=["embedding"], +async def test_client_cleanup() -> None: + """Test client cleanup.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.95, "title": "Vector Result"}]) - - custom_vectors = [0.1, 0.2, 0.3, 0.4, 0.5] - query_dict = {"query": "test query", "vectors": {"embedding": custom_vectors}} - - with patch.object(tool, "_get_client", return_value=mock_client): - results = await tool.run(query_dict) - - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Vector Result" - - mock_client.search.assert_called_once() - - -@pytest.mark.asyncio -async def test_component_loading_with_invalid_params() -> None: - """Test loading components with invalid parameters.""" - - class OtherClass: - pass - - with pytest.raises(TypeError, match="Cannot create instance"): - 
BaseAzureAISearchTool.load_component( - {"provider": "autogen_ext.tools.azure.BaseAzureAISearchTool", "config": {}}, - expected=OtherClass, # type: ignore - ) - - with pytest.raises(Exception) as excinfo: - ConcreteAzureAISearchTool.load_component("not a dict or ComponentModel") # type: ignore - error_msg = str(excinfo.value).lower() - assert any(text in error_msg for text in ["attribute", "type", "object", "dict", "str"]) + tool._client = mock_client # pyright: ignore[reportPrivateUsage] + await tool.close() - with pytest.raises(Exception) as excinfo: - ConcreteAzureAISearchTool.load_component({}) - error_msg = str(excinfo.value).lower() - assert any(text in error_msg for text in ["validation", "required", "missing", "field"]) + mock_client.close.assert_called_once() + assert tool._client is None # pyright: ignore[reportPrivateUsage] @pytest.mark.asyncio -async def test_factory_method_validation() -> None: - """Test validation in factory methods.""" - with pytest.raises(ValueError, match="endpoint must be a valid URL"): - ConcreteAzureAISearchTool.create_keyword_search( - name="test", endpoint="", index_name="test-index", credential=AzureKeyCredential("test-key") +async def test_config_validation() -> None: + """Test configuration validation.""" + with pytest.raises(ValueError, match="vector_fields must contain at least one field"): + AzureAISearchTool._validate_config( # pyright: ignore[reportPrivateUsage] + { + "name": "test-search", + "endpoint": MOCK_ENDPOINT, + "index_name": MOCK_INDEX, + "credential": MOCK_CREDENTIAL, + }, + "vector", ) - with pytest.raises(ValueError, match="endpoint must be a valid URL"): - ConcreteAzureAISearchTool.create_keyword_search( - name="test", endpoint="invalid-url", index_name="test-index", credential=AzureKeyCredential("test-key") + with pytest.raises(ValueError, match="vector_fields must contain at least one field"): + AzureAISearchTool._validate_config( # pyright: ignore[reportPrivateUsage] + { + "name": "test-search", + 
"endpoint": MOCK_ENDPOINT, + "index_name": MOCK_INDEX, + "credential": MOCK_CREDENTIAL, + "search_fields": ["content"], + }, + "hybrid", ) - with pytest.raises(ValueError, match="index_name cannot be empty"): - ConcreteAzureAISearchTool.create_keyword_search( - name="test", - endpoint="https://test.search.windows.net", - index_name="", - credential=AzureKeyCredential("test-key"), - ) - - with pytest.raises(ValueError, match="name cannot be empty"): - ConcreteAzureAISearchTool.create_keyword_search( - name="", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) - with pytest.raises(ValueError, match="credential cannot be None"): - ConcreteAzureAISearchTool.create_keyword_search( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=None, # type: ignore - ) +@pytest.mark.asyncio +async def test_openai_embedding_provider() -> None: + """Test OpenAI embedding provider.""" + with patch("openai.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + mock_client.embeddings.create.return_value.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] - with pytest.raises(ValueError, match="vector_fields must contain at least one field name"): - ConcreteAzureAISearchTool.create_hybrid_search( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - vector_fields=[], + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="openai", + embedding_model="text-embedding-ada-002", + openai_api_key="test-key", ) - with pytest.raises(ValueError, match="vector_fields must contain at least one field name"): - ConcreteAzureAISearchTool.create_hybrid_search( - name="test", - endpoint="https://test.search.windows.net", 
- index_name="test-index", - credential=AzureKeyCredential("test-key"), - vector_fields=[], - ) + embedding = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] + assert len(embedding) == 3 + assert embedding == [0.1, 0.2, 0.3] @pytest.mark.asyncio -async def test_direct_tool_initialization_error() -> None: - """Test that directly initializing AzureAISearchTool raises an error.""" - - class TestSearchTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return [0.1, 0.2, 0.3] - - with pytest.raises(RuntimeError, match="Constructor is private"): - TestSearchTool( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - query_type="keyword", +async def test_embedding_provider_error_handling() -> None: + """Test error handling in embedding providers.""" + with pytest.raises(ValueError, match="openai_endpoint is required when embedding_provider is 'azure_openai'"): + AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_api_version="2023-11-01", ) + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_endpoint="https://test.openai.azure.com", + openai_api_version="2023-11-01", + ) + with patch("azure.identity.DefaultAzureCredential") as mock_credential: + mock_credential.return_value.get_token.return_value = None + with pytest.raises(ValueError, match="Failed to acquire token using DefaultAzureCredential for Azure OpenAI"): + await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] + + tool = 
AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="unsupported_provider", + embedding_model="test-model", + ) + with pytest.raises(ValueError, match="Unsupported client-side embedding provider: unsupported_provider"): + await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] + @pytest.mark.asyncio -async def test_credential_dict_with_missing_api_key() -> None: - """Test handling of credential dict without api_key.""" - with pytest.raises(ValueError, match="If credential is a dict, it must contain an 'api_key' key"): - ConcreteAzureAISearchTool.create_keyword_search( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential={"invalid_key": "value"}, +async def test_abstract_base_class() -> None: + """Test abstract base class behavior.""" + with pytest.raises(NotImplementedError): + BaseAzureAISearchTool._from_config( # pyright: ignore[reportPrivateUsage] + AzureAISearchConfig( + name="test", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + ) ) @pytest.mark.asyncio -async def test_complex_error_handling_scenarios() -> None: - """Test more complex error handling scenarios.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="error_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def test_client_initialization_errors() -> None: + """Test client initialization error handling.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - permission_error = HttpResponseError() - permission_error.message = "403 Forbidden: Access is denied" - - with patch.object(tool, "_get_client", AsyncMock(side_effect=permission_error)): - with 
pytest.raises(ValueError, match="Error from Azure AI Search"): - await tool.run("test query") - - unexpected_error = Exception("Unexpected error during initialization") + with patch("azure.search.documents.aio.SearchClient.__init__", side_effect=Exception("Connection error")): + with pytest.raises(ValueError, match="Unexpected error initializing search client: Connection error"): + await tool._get_client() # pyright: ignore[reportPrivateUsage] - with patch.object(tool, "_get_client", AsyncMock(side_effect=unexpected_error)): - with pytest.raises(ValueError, match="Error from Azure AI Search"): - await tool.run("test query") + with patch( + "azure.search.documents.aio.SearchClient.__init__", side_effect=ResourceNotFoundError("Index not found") + ): + with pytest.raises(ValueError, match=f"Index '{MOCK_INDEX}' not found"): + await tool._get_client() # pyright: ignore[reportPrivateUsage] @pytest.mark.asyncio -async def test_multi_step_vector_search() -> None: - """Test a multi-step vector search with query embeddings and explicit search options.""" - tool = ConcreteAzureAISearchTool.create_vector_search( - name="vector_multi_step", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - vector_fields=["embedding"], +async def test_client_initialization_with_error() -> None: + """Test client initialization with various errors.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.98, "title": "Vector Embedding Test"}]) + class MockResponse: + def __init__(self, status_code: int, reason: str): + self.status_code = status_code + self.reason = reason + self.request = object() - embedding = [0.1, 0.2, 0.3, 0.4, 0.5] - with patch.object(tool, "_get_embedding", AsyncMock(return_value=embedding)): - with 
patch.object(tool, "_get_client", return_value=mock_client): - results = await tool.run("vector embedding query") + def text(self) -> str: + return f"{self.status_code} {self.reason}" - assert len(results.results) == 1 - assert results.results[0].content["title"] == "Vector Embedding Test" + mock_response = MockResponse(status_code=401, reason="Unauthorized") + with patch( + "azure.search.documents.aio.SearchClient.__init__", side_effect=HttpResponseError(response=mock_response) + ): + with pytest.raises(ValueError, match="Authentication failed"): + await tool._get_client() # pyright: ignore[reportPrivateUsage] - mock_client.search.assert_called_once() + mock_response = MockResponse(status_code=403, reason="Forbidden") + with patch( + "azure.search.documents.aio.SearchClient.__init__", side_effect=HttpResponseError(response=mock_response) + ): + with pytest.raises(ValueError, match="Permission denied"): + await tool._get_client() # pyright: ignore[reportPrivateUsage] - _, kwargs = mock_client.search.call_args - assert "vector_queries" in kwargs + mock_response = MockResponse(status_code=500, reason="Internal Server Error") + with patch( + "azure.search.documents.aio.SearchClient.__init__", side_effect=HttpResponseError(response=mock_response) + ): + with pytest.raises(ValueError, match="Error connecting to Azure AI Search"): + await tool._get_client() # pyright: ignore[reportPrivateUsage] @pytest.mark.asyncio -async def test_error_handling_in_special_cases() -> None: - """Test error handling for specific error cases that might be missed.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="error_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) +async def test_search_document_processing_error(mock_search_client: AsyncMock) -> None: + """Test error handling during search document processing.""" + mock_search_client.search.return_value.__aiter__.return_value = [{"invalid": 
"document", "@search.score": None}] - not_found_error = ValueError("The requested resource with 'test-index' was not found") - - with patch.object(tool, "_get_client", AsyncMock(side_effect=not_found_error)): - with pytest.raises(ValueError, match="Index 'test-index' not found"): - await tool.run("error query") - - auth_error = ValueError("401 Unauthorized error occurred") + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + ) - with patch.object(tool, "_get_client", AsyncMock(side_effect=auth_error)): - with pytest.raises(ValueError, match="Authentication failed"): - await tool.run("auth error query") + with patch.object(tool, "_get_client", return_value=mock_search_client): + results = await tool.run("test query") + assert len(results.results) == 0 @pytest.mark.asyncio -async def test_component_loading_with_config_model() -> None: - """Test the load_component method with a ComponentModel instead of dict.""" - from autogen_core import ComponentModel - - model = ComponentModel( - provider="autogen_ext.tools.azure.BaseAzureAISearchTool", - config={ - "name": "model_test", - "endpoint": "https://test.search.windows.net", - "index_name": "test-index", - "credential": {"api_key": "test-key"}, - "query_type": "keyword", - "search_fields": ["title", "content"], - }, +async def test_search_with_expired_cache( + mock_search_client: AsyncMock, mock_search_results: List[Dict[str, Any]] +) -> None: + """Test search with expired cache.""" + mock_search_client.search.return_value.__aiter__.return_value = mock_search_results + + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + enable_caching=True, + cache_ttl_seconds=1, ) - with patch.object(ConcreteAzureAISearchTool, "create_keyword_search") as mock_create: - mock_create.return_value = 
ConcreteAzureAISearchTool.create_keyword_search( - name="model_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - ) + with patch.object(tool, "_get_client", return_value=mock_search_client): + await tool.run("test query") + assert mock_search_client.search.call_count == 1 - tool = ConcreteAzureAISearchTool.load_component(model) + await asyncio.sleep(1.1) - assert tool.name == "model_test" + await tool.run("test query") + assert mock_search_client.search.call_count == 2 @pytest.mark.asyncio -async def test_fallback_vectorizable_text_query() -> None: - """Test the fallback VectorizableTextQuery class when Azure SDK is not available.""" - - class MockVectorizableTextQuery: - def __init__(self, text: str, k_nearest_neighbors: int, fields: str) -> None: - self.text = text - self.k_nearest_neighbors = k_nearest_neighbors - self.fields = fields - - query1 = MockVectorizableTextQuery(text="test query", k_nearest_neighbors=5, fields="title") - assert query1.text == "test query" - assert query1.fields == "title" - - query2 = MockVectorizableTextQuery(text="test query", k_nearest_neighbors=3, fields="title,content") - assert query2.text == "test query" - assert query2.fields == "title,content" +async def test_search_with_invalid_credential() -> None: + """Test search with invalid credential format.""" + with pytest.raises(ValueError, match="Invalid configuration"): + AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential={"api_key": "test-key"}, # type: ignore + ) @pytest.mark.asyncio -async def test_dump_component() -> None: - """Test the dump_component method.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="dump_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def test_search_with_empty_query() -> None: + """Test search with 
empty query.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - component_model = tool.dump_component() - assert component_model.provider == "autogen_ext.tools.azure.BaseAzureAISearchTool" - assert component_model.config["name"] == "dump_test" - assert component_model.config["endpoint"] == "https://test.search.windows.net" - assert component_model.config["index_name"] == "test-index" + with pytest.raises(ValueError, match="Search query cannot be empty"): + await tool.run("") @pytest.mark.asyncio -async def test_fallback_config_class() -> None: - """Test the fallback configuration class.""" - from autogen_ext.tools.azure._ai_search import _FallbackAzureAISearchConfig # pyright: ignore[reportPrivateUsage] - - config = _FallbackAzureAISearchConfig( - name="fallback_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - query_type="vector", +async def test_vector_search_without_query() -> None: + """Test vector search with empty query.""" + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, vector_fields=["embedding"], - top=10, ) - assert config.name == "fallback_test" - assert config.endpoint == "https://test.search.windows.net" - assert config.index_name == "test-index" - assert config.query_type == "vector" - assert config.vector_fields == ["embedding"] - assert config.top == 10 + with pytest.raises(ValueError, match="Search query cannot be empty"): + await tool.run("") @pytest.mark.asyncio -async def test_search_with_different_query_types() -> None: - """Test search with different query types and parameters.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="query_types_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def 
test_search_with_cancellation_token_already_cancelled() -> None: + """Test search with already cancelled token.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, ) - mock_client = AsyncMock() - mock_client.search.return_value = MockAsyncIterator([{"@search.score": 0.9, "title": "Test Result"}]) - - with patch.object(tool, "_get_client", return_value=mock_client): - await tool.run("string query") - mock_client.search.assert_called_once() - mock_client.search.reset_mock() - - await tool.run({"query": "dict query"}) - mock_client.search.assert_called_once() - mock_client.search.reset_mock() + token = CancellationToken() + token.cancel() - await tool.run(SearchQuery(query="object query")) - mock_client.search.assert_called_once() - - -class MockEmbeddingData: - """Mock for OpenAI embedding data.""" - - def __init__(self, embedding: List[float]): - self.embedding = embedding - - -class MockEmbeddingResponse: - """Mock for OpenAI embedding response.""" - - def __init__(self, data: List[MockEmbeddingData]): - self.data = data + with pytest.raises(asyncio.CancelledError): + await tool.run("test query", cancellation_token=token) @pytest.mark.asyncio -async def test_get_embedding_methods() -> None: - """Test the _get_embedding method with different providers.""" - - class TestSearchTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return [0.1, 0.2, 0.3] +async def test_config_validation_edge_cases() -> None: + """Test configuration validation edge cases.""" + with pytest.raises( + ValueError, + match="Invalid configuration: 1 validation error for AzureAISearchConfig\n Value error, vector_fields must be provided for vector search", + ): + AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=[], + ) - with 
patch.object(AzureAISearchTool, "_get_embedding", autospec=True) as mock_get_embedding: - mock_get_embedding.return_value = [0.1, 0.2, 0.3] + with pytest.raises(ValueError, match="vector_fields must contain at least one field name for hybrid search"): + AzureAISearchTool.create_hybrid_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=[], + search_fields=["content"], + ) - tool = TestSearchTool.create_vector_search( - name="test_vector_search", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), + with pytest.raises(ValueError, match="semantic_config_name is required when query_type is 'semantic'"): + AzureAISearchTool.create_hybrid_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, vector_fields=["embedding"], + search_fields=["content"], + query_type="semantic", ) - result = await AzureAISearchTool._get_embedding(tool, "test query") # pyright: ignore[reportPrivateUsage] - assert result == [0.1, 0.2, 0.3] - mock_get_embedding.assert_called_once_with(tool, "test query") - @pytest.mark.asyncio -async def test_get_embedding_azure_openai_path() -> None: - """Test the Azure OpenAI path in _get_embedding.""" - mock_azure_openai = AsyncMock() - mock_azure_openai.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])]) +async def test_base_class_functionality() -> None: + """Test base class functionality.""" + config = AzureAISearchConfig( + name="test", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + ) - with ( - patch("openai.AsyncAzureOpenAI", return_value=mock_azure_openai), - patch("azure.identity.DefaultAzureCredential"), - patch("autogen_ext.tools.azure._ai_search.getattr") as mock_getattr, - ): + with pytest.raises(NotImplementedError, match="BaseAzureAISearchTool.*cannot be instantiated directly"): + 
BaseAzureAISearchTool._from_config(config) # pyright: ignore[reportPrivateUsage] - def side_effect(obj: Any, name: str, default: Any = None) -> Any: - if name == "embedding_provider": - return "azure_openai" - elif name == "embedding_model": - return "text-embedding-ada-002" - elif name == "openai_endpoint": - return "https://test.openai.azure.com" - elif name == "openai_api_key": - return "test-key" - return default - - mock_getattr.side_effect = side_effect - - class TestTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return await AzureAISearchTool._get_embedding(self, query) - - token = _allow_private_constructor.set(True) - try: - tool = TestTool( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - query_type="vector", - vector_fields=["embedding"], - ) + class TestSearchTool(BaseAzureAISearchTool): + async def _get_embedding(self, query: str) -> List[float]: + return [0.1, 0.2, 0.3] - result = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] - assert result == [0.1, 0.2, 0.3] - mock_azure_openai.embeddings.create.assert_called_once_with( - model="text-embedding-ada-002", input="test query" + @classmethod + def _from_config(cls, config: AzureAISearchConfig) -> "TestSearchTool": + return cls( + name=config.name, + endpoint=config.endpoint, + index_name=config.index_name, + credential=config.credential, ) - finally: - _allow_private_constructor.reset(token) + + tool = TestSearchTool._from_config(config) # pyright: ignore[reportPrivateUsage] + assert tool.name == "test" + assert tool.search_config.endpoint == MOCK_ENDPOINT @pytest.mark.asyncio -async def test_get_embedding_openai_path() -> None: - """Test the OpenAI path in _get_embedding.""" - mock_openai = AsyncMock() - mock_openai.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.4, 0.5, 0.6])]) +async def test_client_cleanup_edge_cases() 
-> None: + """Test client cleanup edge cases.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + ) - with ( - patch("openai.AsyncOpenAI", return_value=mock_openai), - patch("autogen_ext.tools.azure._ai_search.getattr") as mock_getattr, - ): + tool._client = None # pyright: ignore[reportPrivateUsage] + await tool.close() - def side_effect(obj: Any, name: str, default: Any = None) -> Any: - if name == "embedding_provider": - return "openai" - elif name == "embedding_model": - return "text-embedding-3-small" - elif name == "openai_api_key": - return "test-key" - return default + mock_client = AsyncMock() + mock_client.close.side_effect = Exception("Failed to close") + tool._client = mock_client # pyright: ignore[reportPrivateUsage] + await tool.close() + assert tool._client is None # pyright: ignore[reportPrivateUsage] - mock_getattr.side_effect = side_effect - class TestTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return await AzureAISearchTool._get_embedding(self, query) +@pytest.mark.asyncio +async def test_token_acquisition_edge_cases() -> None: + """Test token acquisition edge cases.""" + with patch("azure.identity.DefaultAzureCredential") as mock_credential: + mock_credential.return_value.get_token.return_value = None - token = _allow_private_constructor.set(True) - try: - tool = TestTool( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - query_type="vector", - vector_fields=["embedding"], - ) + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + openai_endpoint="https://test.openai.azure.com", + 
openai_api_version="2023-11-01", + ) - result = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] - assert result == [0.4, 0.5, 0.6] - mock_openai.embeddings.create.assert_called_once_with(model="text-embedding-3-small", input="test query") - finally: - _allow_private_constructor.reset(token) + with pytest.raises(ValueError, match="Failed to acquire token using DefaultAzureCredential for Azure OpenAI"): + await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] @pytest.mark.asyncio -async def test_get_embedding_error_cases_direct() -> None: - """Test error cases in the _get_embedding method.""" - - class DirectEmbeddingTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return await super()._get_embedding(query) - - token = _allow_private_constructor.set(True) - try: - tool = DirectEmbeddingTool( - name="error_embedding_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - query_type="vector", +async def test_hybrid_search_validation() -> None: + """Test hybrid search validation edge cases.""" + with pytest.raises(ValueError, match="semantic_config_name is required when query_type is 'semantic'"): + AzureAISearchTool.create_hybrid_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, vector_fields=["embedding"], + search_fields=["content"], + query_type="semantic", ) - with pytest.raises( - ValueError, match="To use vector search, you must provide embedding_provider and embedding_model" - ): - await tool._get_embedding("test query") - - tool.search_config.embedding_provider = "azure_openai" - with pytest.raises( - ValueError, match="To use vector search, you must provide embedding_provider and embedding_model" - ): - await tool._get_embedding("test query") + with pytest.raises(ValueError, match="openai_endpoint is required when embedding_provider is 
'azure_openai'"): + AzureAISearchTool.create_hybrid_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + search_fields=["content"], + embedding_provider="azure_openai", + embedding_model="text-embedding-ada-002", + ) - tool.search_config.embedding_model = "text-embedding-ada-002" - def missing_endpoint_side_effect(obj: Any, name: str, default: Any = None) -> Any: - if name == "openai_endpoint": - return None - return getattr(obj, name, default) +@pytest.mark.asyncio +async def test_search_result_caching() -> None: + """Test that search results are properly cached and retrieved.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + enable_caching=True, + cache_ttl_seconds=10, + ) - with patch( - "autogen_ext.tools.azure._ai_search.getattr", - side_effect=missing_endpoint_side_effect, - ): - with pytest.raises(ValueError, match="OpenAI endpoint must be provided"): - await tool._get_embedding("test query") + mock_results = [{"id": "1", "content": "Test", "@search.score": 0.8}] - tool.search_config.embedding_provider = "unsupported_provider" + with patch.object(tool, "_get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.search.return_value.__aiter__.return_value = mock_results + mock_get_client.return_value = mock_client - def unsupported_provider_side_effect(obj: Any, name: str, default: Any = None) -> Any: - if name == "openai_endpoint": - return "https://test.openai.azure.com" - return getattr(obj, name, default) + result1 = await tool.run("test query") + assert len(result1.results) == 1 + assert mock_client.search.call_count == 1 - with patch( - "autogen_ext.tools.azure._ai_search.getattr", - side_effect=unsupported_provider_side_effect, - ): - with pytest.raises(ValueError, match="Unsupported embedding provider"): - await tool._get_embedding("test 
query") - finally: - _allow_private_constructor.reset(token) + result2 = await tool.run("test query") + assert len(result2.results) == 1 + assert mock_client.search.call_count == 1 @pytest.mark.asyncio -async def test_azure_openai_with_default_credential() -> None: - """Test Azure OpenAI with DefaultAzureCredential.""" - - mock_azure_openai = AsyncMock() - mock_azure_openai.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])]) - - mock_credential = MagicMock() - mock_token = MagicMock() - mock_token.token = "mock-token" - mock_credential.get_token.return_value = mock_token - - with ( - patch("openai.AsyncAzureOpenAI") as mock_azure_openai_class, - patch("azure.identity.DefaultAzureCredential", return_value=mock_credential), - patch("autogen_ext.tools.azure._ai_search.getattr") as mock_getattr, - ): - mock_azure_openai_class.return_value = mock_azure_openai - - def side_effect(obj: Any, name: str, default: Any = None) -> Any: - if name == "embedding_provider": - return "azure_openai" - elif name == "embedding_model": - return "text-embedding-ada-002" - elif name == "openai_endpoint": - return "https://test.openai.azure.com" - elif name == "openai_api_version": - return "2023-05-15" - return default - - mock_getattr.side_effect = side_effect - - class TestTool(AzureAISearchTool): - async def _get_embedding(self, query: str) -> List[float]: - return await AzureAISearchTool._get_embedding(self, query) - - token = _allow_private_constructor.set(True) - try: - tool = TestTool( - name="test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), - query_type="vector", - vector_fields=["embedding"], - ) - - token_provider: Optional[Callable[[], str]] = None +async def test_cache_expiration() -> None: + """Test that cached results expire after TTL.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + 
credential=MOCK_CREDENTIAL, + enable_caching=True, + cache_ttl_seconds=1, + ) - def capture_token_provider( - api_key: Optional[str] = None, - azure_ad_token_provider: Optional[Callable[[], str]] = None, - **kwargs: Any, - ) -> AsyncMock: - nonlocal token_provider - if azure_ad_token_provider: - token_provider = azure_ad_token_provider - return mock_azure_openai + mock_results = [{"id": "1", "content": "Test", "@search.score": 0.8}] - mock_azure_openai_class.side_effect = capture_token_provider + with patch.object(tool, "_get_client") as mock_get_client: + mock_client = AsyncMock() + mock_client.search.return_value.__aiter__.return_value = mock_results + mock_get_client.return_value = mock_client - result = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage] - assert result == [0.1, 0.2, 0.3] + await tool.run("test query") + assert mock_client.search.call_count == 1 - assert token_provider is not None - token_provider() - mock_credential.get_token.assert_called_once_with("https://cognitiveservices.azure.com/.default") + await asyncio.sleep(1.1) - mock_azure_openai.embeddings.create.assert_called_once_with( - model="text-embedding-ada-002", input="test query" - ) - finally: - _allow_private_constructor.reset(token) + await tool.run("test query") + assert mock_client.search.call_count == 2 @pytest.mark.asyncio -async def test_schema_property() -> None: - """Test the schema property correctly defines the JSON schema for the tool.""" - tool = ConcreteAzureAISearchTool.create_keyword_search( - name="schema_test", - endpoint="https://test.search.windows.net", - index_name="test-index", - credential=AzureKeyCredential("test-key"), +async def test_search_field_validation() -> None: + """Test validation of search fields configuration.""" + tool = AzureAISearchTool.create_full_text_search( + name="test-search", endpoint=MOCK_ENDPOINT, index_name=MOCK_INDEX, credential=MOCK_CREDENTIAL, search_fields=[] ) + assert tool.search_config.search_fields == 
[] - schema = tool.schema + tool = AzureAISearchTool.create_full_text_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + search_fields=["content", "content"], + ) + assert tool.search_config.search_fields == ["content", "content"] - assert schema["name"] == "schema_test" - assert "description" in schema - parameters = schema.get("parameters", {}) # pyright: ignore - assert parameters.get("type") == "object" # pyright: ignore +@pytest.mark.asyncio +async def test_api_version_handling() -> None: + """Test handling of different API versions.""" + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + api_version="2023-11-01", + ) + assert tool.search_config.api_version == "2023-11-01" - properties = parameters.get("properties", {}) # pyright: ignore - assert "query" in properties # pyright: ignore + tool = AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + ) + assert tool.search_config.api_version == DEFAULT_API_VERSION - required = parameters.get("required", []) # pyright: ignore - assert "query" in required # pyright: ignore + with patch("autogen_ext.tools.azure._ai_search.logger") as mock_logger: + AzureAISearchTool.create_vector_search( + name="test-search", + endpoint=MOCK_ENDPOINT, + index_name=MOCK_INDEX, + credential=MOCK_CREDENTIAL, + vector_fields=["embedding"], + api_version="2023-11-01", + ) - assert schema.get("strict") is True # pyright: ignore + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "vector search" in warning_msg.lower() + assert "2023-11-01" in warning_msg