Source code for agent_inspect.clients.azure_openai_client

import os
import requests
import asyncio
import backoff

from openai import AzureOpenAI, APIStatusError, RateLimitError
from typing import Any, Dict
import logging

from agent_inspect.clients.llm_client import LLMClient
from agent_inspect.models.llm_payload import LLMPayload
from agent_inspect.models.llm_response import LLMResponse

from agent_inspect.metrics.constants import STATUS_200, STATUS_500, STATUS_429, MAX_RETRY_ATTEMPTS_EXCEEDED

logger = logging.getLogger(__name__)


[docs] def backoff_handler(details): """ Custom handler for backoff events to log messages with extra attributes. """ message = ( f"Backing off {details['target'].__name__} " f"(args={details['args']}, kwargs={details['kwargs']}) " f"for {details['wait']:.1f}s " f"(tries={details['tries']}, elapsed={details['elapsed']:.1f}s, " f"exception: {details['exception']})" ) logger.warning(message)
[docs] def give_up_handler(details): """ Custom handler for when max retries are reached. """ message = ( f"Max retries reached for {details['target'].__name__} " f"(args={details['args']}, kwargs={details['kwargs']}, " f"elapsed={details['elapsed']:.1f}s, exception: {details['exception']})" ) logger.error(message)
[docs] class AzureOpenAIClient(LLMClient): """ Client class providing connection to the Azure OpenAI Service. Need to set the following environment variables: ``AZURE_API_VERSION``, ``AZURE_API_BASE``, ``AZURE_API_KEY``. :param model: the selected Azure OpenAI model which will receive the prompt. This is the deployment name in Azure. :param max_tokens: the maximum number of tokens allowed for the LLM to generate. :param temperature: the temperature setting for LLM model. Default to ``0``. """ def __init__(self, model: str, max_tokens: int, temperature: float = 0): self.model = model self.max_tokens = max_tokens self.chat_client = AzureOpenAI( api_version=os.environ['AZURE_API_VERSION'], azure_endpoint=os.environ['AZURE_API_BASE'], api_key=os.environ['AZURE_API_KEY'], ) self.temperature = temperature
[docs] @backoff.on_exception( backoff.expo, (APIStatusError, RateLimitError, requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.RequestException), max_tries=10, max_time=300, jitter=None, on_backoff=backoff_handler, on_giveup=give_up_handler, giveup=lambda e: isinstance(e, APIStatusError) and e.status_code in [400, 401, 403, 404] ) async def make_llm_request_with_retry(self, prompt: str): messages = [{"role": "user", "content": prompt}] azure_openai_response = self.chat_client.chat.completions.create( model=self.model, messages=messages, max_tokens=self.max_tokens, temperature=self.temperature ) return azure_openai_response
[docs] async def make_llm_request(self, prompt: str) -> LLMResponse: """ Returns LLM completion after sending a prompt to the selected the model. Uses an exponential backoff retry mechanism for transient failures. :param prompt: the provided prompt to send to the model. :return: :obj:`~agent_inspect.models.llm_response.LLMResponse` object containing status code, completion and error message. """ # Backoff decorator handles retries; this try-catch handles final failures try: azure_openai_response = await self.make_llm_request_with_retry(prompt) return LLMResponse(status=STATUS_200, completion=azure_openai_response.choices[0].message.content) except RateLimitError: return LLMResponse(status=STATUS_429, completion="", error_message=MAX_RETRY_ATTEMPTS_EXCEEDED) except APIStatusError as e: # Non-retryable errors (4xx) or final failure after all retries return LLMResponse( status=e.status_code, completion="", error_message=f"Azure OpenAI API Error: {e.message}" ) except Exception as e: # Catch any other unexpected errors return LLMResponse( status=STATUS_500, completion="", error_message=f"Unexpected error: {str(e)}" )
[docs] async def make_llm_requests(self, prompts: list[str]) -> list[LLMResponse]: """ Returns LLM completion after sending a prompt to the selected the model. :param prompts: a list of provided prompts to send to the model. :return: a :obj:`~typing.List` [:obj:`~agent_inspect.models.llm_response.LLMResponse`] object containing status codes, completions and error messages. """ responses = await asyncio.gather(*(self.make_llm_request(prompt) for prompt in prompts)) return list(responses)
[docs] def convert_payload_to_raw_request(self, payload: LLMPayload) -> Dict[str, Any]: raw_request: Dict[str, Any] = {} messages = [] if payload.system_prompt: messages.append({"role": "system", "content": payload.system_prompt}) messages.append({"role": "user", "content": payload.user_prompt}) raw_request["model"] = payload.model if payload.model else self.model raw_request["messages"] = messages raw_request["temperature"] = payload.temperature if payload.temperature else self.temperature raw_request["max_tokens"] = payload.max_tokens if payload.max_tokens else self.max_tokens if payload.structured_output: raw_request["response_format"] = payload.structured_output return raw_request
[docs] @backoff.on_exception( backoff.expo, (APIStatusError, requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.RequestException), max_tries=10, max_time=300, jitter=None, on_backoff=backoff_handler, on_giveup=give_up_handler, giveup=lambda e: isinstance(e, APIStatusError) and e.status_code in [400, 401, 403, 404] ) async def make_request_with_payload_using_retry(self, payload: LLMPayload): raw_request = self.convert_payload_to_raw_request(payload) azure_response = self.chat_client.chat.completions.create(**raw_request) return azure_response
[docs] async def make_request_with_payload(self, payload: LLMPayload) -> LLMResponse: """ Returns LLM completion after sending a payload to the selected the model. :param payload: the provided payload to send to the model. :return: :obj:`~agent_inspect.models.llm_response.LLMResponse` object containing status code, completion and error message. """ try: azure_response = await self.make_request_with_payload_using_retry(payload) response = LLMResponse(status=STATUS_200, completion=azure_response.choices[0].message.content) except RateLimitError: response = LLMResponse(status=STATUS_429, completion="", error_message=MAX_RETRY_ATTEMPTS_EXCEEDED) except APIStatusError as e: response = LLMResponse(status=e.status_code, completion="", error_message=e.message) except Exception as e: response = LLMResponse(status=STATUS_500, completion="", error_message=str(e)) return response