Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,16 @@ def __init__(self, model_config, *, threshold=3, credential=None, **kwargs):
current_dir = os.path.dirname(__file__)
prompty_path = os.path.join(current_dir, self._PROMPTY_FILE_NO_QUERY) # Default to no query

self._higher_is_better = True
super().__init__(
model_config=model_config,
prompty_file=prompty_path,
result_key=self._RESULT_KEY,
threshold=threshold,
credential=credential,
_higher_is_better=self._higher_is_better,
_higher_is_better=True,
**kwargs,
)
self._model_config = model_config
self.threshold = threshold
# Needs to be set because it's used in call method to re-validate prompt if `query` is provided

@overload
Expand Down Expand Up @@ -206,26 +204,53 @@ def __call__( # pylint: disable=docstring-missing-param

return super().__call__(*args, **kwargs)

def _ensure_query_prompty_loaded(self):
"""Switch to the query prompty file if not already loaded."""
def _load_prompty_file(self, prompty_filename: str):
    """Load the specified prompty file if not already loaded.

    Switches the evaluator's active prompty (query vs. no-query variant) and
    rebuilds the async flow so subsequent evaluations use the new prompt.

    :param prompty_filename: The name of the prompty file to load.
    :type prompty_filename: str
    """
    if self._prompty_file.endswith(prompty_filename):
        return  # Already using the correct prompty file

    # Resolve the prompty file relative to this module's directory.
    current_dir = os.path.dirname(__file__)
    self._prompty_file = os.path.join(current_dir, prompty_filename)

    # Re-validate the model config and rebuild the flow for the new prompt.
    prompty_model_config = construct_prompty_model_config(
        validate_model_config(self._model_config),
        self._DEFAULT_OPEN_API_VERSION,
        UserAgentSingleton().value,
    )
    self._flow = AsyncPrompty.load(
        source=self._prompty_file, model=prompty_model_config, is_reasoning_model=self._is_reasoning_model
    )

def _ensure_query_prompty_loaded(self):
"""Switch to the query prompty file if not already loaded."""
self._load_prompty_file(self._PROMPTY_FILE_WITH_QUERY)

def _ensure_no_query_prompty_loaded(self):
"""Switch to the no-query prompty file if not already loaded."""
self._load_prompty_file(self._PROMPTY_FILE_NO_QUERY)

def _has_context(self, eval_input: dict) -> bool:
"""
Return True if eval_input contains a non-empty 'context' field.
Treats None, empty strings, empty lists, and lists of empty strings as no context.
"""
context = eval_input.get("context", None)
return self._validate_context(context)

def _validate_context(self, context) -> bool:
"""
Validate if the provided context is non-empty and meaningful.
Treats None, empty strings, empty lists, and lists of empty strings as no context.
:param context: The context to validate
:type context: Union[str, List, None]
:return: True if context is valid and non-empty, False otherwise
:rtype: bool
"""
if not context:
return False
if context == "<>": # Special marker for no context
Expand All @@ -239,8 +264,10 @@ def _has_context(self, eval_input: dict) -> bool:
@override
async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:
if eval_input.get("query", None) is None:
self._ensure_no_query_prompty_loaded()
return await super()._do_eval(eval_input)

self._ensure_query_prompty_loaded()
contains_context = self._has_context(eval_input)

simplified_query = simplify_messages(eval_input["query"], drop_tool_calls=contains_context)
Expand Down Expand Up @@ -272,22 +299,27 @@ async def _real_call(self, **kwargs):
return {
self._result_key: self._NOT_APPLICABLE_RESULT,
f"{self._result_key}_result": "pass",
f"{self._result_key}_threshold": self.threshold,
f"{self._result_key}_threshold": self._threshold,
f"{self._result_key}_reason": f"Supported tools were not called. Supported tools for groundedness are {self._SUPPORTED_TOOLS}.",
}
else:
raise ex

def _is_single_entry(self, value):
"""Determine if the input value represents a single entry; otherwise return False."""
if isinstance(value, str):
return True
if isinstance(value, list) and len(value) == 1:
return True
return False

def _convert_kwargs_to_eval_input(self, **kwargs):
if kwargs.get("context") or kwargs.get("conversation"):
return super()._convert_kwargs_to_eval_input(**kwargs)
query = kwargs.get("query")
response = kwargs.get("response")
tool_definitions = kwargs.get("tool_definitions")

if query and self._prompty_file != self._PROMPTY_FILE_WITH_QUERY:
self._ensure_query_prompty_loaded()

if (not query) or (not response): # or not tool_definitions:
msg = f"{type(self).__name__}: Either 'conversation' or individual inputs must be provided. For Agent groundedness 'query' and 'response' are required."
raise EvaluationException(
Expand All @@ -298,7 +330,16 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
)
context = self._get_context_from_agent_response(response, tool_definitions)

filtered_response = self._filter_file_search_results(response)
if not self._validate_context(context) and self._is_single_entry(response) and self._is_single_entry(query):
msg = f"{type(self).__name__}: No valid context provided or could be extracted from the query or response."
raise EvaluationException(
message=msg,
blame=ErrorBlame.USER_ERROR,
category=ErrorCategory.NOT_APPLICABLE,
target=ErrorTarget.GROUNDEDNESS_EVALUATOR,
)

filtered_response = self._filter_file_search_results(response) if self._validate_context(context) else response
return super()._convert_kwargs_to_eval_input(response=filtered_response, context=context, query=query)

def _filter_file_search_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
Expand Down
Loading