import codecs
import inspect
import json
from datetime import datetime
from typing import Any, Callable, cast, get_args
from fews_openapi_py_client import AuthenticatedClient, Client
from fews_openapi_py_client.types import Unset
from requests import HTTPError
# Names exported by ``from <module> import *``.
__all__ = ["ApiEndpoint"]
class ApiEndpoint:
"""Wraps a single API endpoint with parameter handling and validation."""
endpoint_function: Callable[..., Any]
success_status_codes: frozenset[int] = frozenset({200})
[docs]
def execute(
self,
*,
client: AuthenticatedClient | Client,
**kwargs: Any,
) -> Any:
"""
Execute the API endpoint call.
Args:
client: AuthenticatedClient or Client instance for API calls.
document_format: Format of the returned document.
**kwargs: Keyword arguments for the API call.
Returns:
Parsed response content based on the returned content type.
"""
response = self.endpoint_function(client=client, **kwargs)
if response.status_code not in self.success_status_codes:
self._request_error_handler(response)
return self._parse_response_content(response)
def _get_parameter_models(self) -> dict[str, dict[str, Any]]:
"""
Extract parameter models from the API endpoint function signature.
Identifies enum types and boolean flags from function parameter
annotations, excluding standard types and the 'client' parameter.
Returns:
dict: Mapping of parameter names to model information containing:
- 'is_bool': bool indicating if the parameter is a boolean enum
- 'model': The enum class for the parameter
Raises:
ValueError: If a parameter has unexpected annotation structure.
"""
function_params = inspect.signature(self.endpoint_function).parameters
standard_types = (str, int, float, bool, list, dict, tuple, set, datetime)
parameter_models: dict[str, dict[str, Any]] = {}
for param_name, param in function_params.items():
if param_name == "client":
continue
annotation = param.annotation
args = get_args(annotation)
# Check if argument annotation contains standard types
if self._contains_types(args, standard_types) or not args:
continue
arg_list = list(args)
if Unset in arg_list:
arg_list.remove(Unset)
if not len(arg_list) == 1:
raise ValueError(
f"Expected two annotation arguments, but got"
f" {len(arg_list)} for {param_name}"
)
model = arg_list[0]
if not hasattr(model, "__members__"):
continue
m_dict: dict[str, Any] = {}
if "TRUE" in model.__members__.keys():
m_dict["is_bool"] = True
else:
m_dict["is_bool"] = False
m_dict["model"] = model
parameter_models[param_name] = m_dict
return parameter_models
def _contains_types(
self, args: tuple[Any, ...] | list[Any], check_types: tuple[type[Any], ...]
) -> bool:
"""
Recursively check if any type in args is contained in check_types.
Handles nested generic types like list[str] or dict[str, int].
Args:
args: Tuple of type arguments to check.
check_types: Tuple of types to check against.
Returns:
bool: True if any arg matches a type in check_types, False otherwise.
"""
for arg in args:
if arg in check_types:
return True
nested_args = get_args(arg)
if nested_args and self._contains_types(nested_args, check_types):
return True
return False
def _convert_bools(self, arg: bool) -> str:
"""
Convert a boolean value to the string representation expected by the API.
Args:
arg: Boolean value to convert.
Returns:
str: "true" if arg is truthy, "false" otherwise.
"""
if arg is True:
return "true"
if arg is False:
return "false"
raise ValueError(f"Expected boolean value, got {arg}")
def _parse_response_content(self, response: Any) -> Any:
"""Parse a successful response using its returned content type."""
content_type = response.headers.get("content-type", "")
media_type = content_type.split(";", 1)[0].strip().lower()
if media_type.endswith("json") or media_type.endswith("+json"):
return json.loads(
self._decode_response_body(response.content, content_type)
)
if media_type.startswith("text/") or media_type in {
"application/xml",
"text/xml",
}:
return self._decode_response_body(response.content, content_type)
return cast(bytes, response.content)
def _decode_response_body(self, content: bytes, content_type: str) -> str:
"""Decode response bytes using the declared charset when available."""
encoding = self._get_response_encoding(content_type)
try:
return content.decode(encoding)
except UnicodeDecodeError:
return content.decode(encoding, errors="replace")
def _get_response_encoding(self, content_type: str) -> str:
"""Resolve a response charset parameter to a Python codec name."""
for parameter in content_type.split(";")[1:]:
name, separator, value = parameter.partition("=")
if separator and name.strip().lower() == "charset":
encoding = value.strip().strip('"').strip("'")
if encoding:
try:
return codecs.lookup(encoding).name
except LookupError:
break
return "utf-8"
def _request_error_handler(self, response: Any) -> None:
"""Handle request errors by raising exceptions for non-200 responses."""
content_type = response.headers.get("content-type", "")
response_body = self._decode_response_body(response.content, content_type)
raise HTTPError(
f"Request failed with status code {response.status_code}: {response_body}"
)