import logging

from openai import APIConnectionError, APIError, APITimeoutError, RateLimitError
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import exception_handler


logger = logging.getLogger("chat.api")


def custom_exception_handler(exc, context):
    """Convert an exception raised in a view into a uniform JSON error response.

    Every response produced here carries the body
    ``{"error": {"code": <str>, "message": <str>}}``.

    Args:
        exc: The exception raised while handling the request.
        context: Handler context, forwarded unchanged to DRF's default
            ``exception_handler``.

    Returns:
        A ``Response``. When DRF's default handler produces a response, that
        same response object is returned with only its ``data`` rewritten
        into the envelope. Otherwise the exception is logged and mapped to
        429 (``RateLimitError``), 504 (``APITimeoutError``), 502 (any other
        ``APIConnectionError``/``APIError``) or 500, always with the code
        ``"server_error"``.
    """
    # Function-scope import keeps this block self-contained; the module
    # already imports ``exception_handler`` from the same package.
    from rest_framework.views import set_rollback

    response = exception_handler(exc, context)

    if response is not None:
        # Reuse DRF's response object and replace only the body.
        response.data = {
            "error": {
                "code": _extract_error_code(exc, response.status_code),
                "message": _extract_error_message(response),
            }
        }
        return response

    # The specific OpenAI errors are tested before the generic
    # APIConnectionError/APIError branch so they get their own status.
    if isinstance(exc, RateLimitError):
        status_code = status.HTTP_429_TOO_MANY_REQUESTS
        message = "OpenAI rate limit reached. Try again shortly."
    elif isinstance(exc, APITimeoutError):
        status_code = status.HTTP_504_GATEWAY_TIMEOUT
        message = "OpenAI request timed out."
    elif isinstance(exc, (APIConnectionError, APIError)):
        status_code = status.HTTP_502_BAD_GATEWAY
        message = "Upstream AI provider error."
    else:
        status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        message = "Internal server error."

    # We are not inside an ``except`` block here, so the exception instance
    # is passed explicitly to get its traceback into the log record.
    logger.exception("Unhandled API exception", exc_info=exc)

    # DRF's default handler marks the transaction for rollback only for the
    # exceptions it handles itself. Because this function turns the remaining
    # exceptions into a normal response, the view no longer raises; without
    # this call a request running under ATOMIC_REQUESTS would commit whatever
    # was written before the failure.
    set_rollback()

    return Response(
        {"error": {"code": "server_error", "message": message}},
        status=status_code,
    )


def _extract_error_code(exc, status_code):
    """Return the machine-readable error code for a handled exception.

    Args:
        exc: The exception that produced the response.
        status_code: HTTP status code of the response.

    Returns:
        A fixed code for 400, 401, 403, 404 and 429. For any other status,
        the exception's ``default_code`` attribute, or ``"api_error"`` when
        the exception has no such attribute.
    """
    known_codes = (
        (status.HTTP_400_BAD_REQUEST, "bad_request"),
        (status.HTTP_401_UNAUTHORIZED, "authentication_failed"),
        (status.HTTP_403_FORBIDDEN, "permission_denied"),
        (status.HTTP_404_NOT_FOUND, "not_found"),
        (status.HTTP_429_TOO_MANY_REQUESTS, "throttled"),
    )
    for known_status, error_code in known_codes:
        if status_code == known_status:
            return error_code

    # The status-based codes above take precedence over the exception's own.
    return getattr(exc, "default_code", "api_error")


def _first_error_leaf(node):
    """Return the first non-empty leaf of a nested error payload as a string.

    Dict values and list/tuple items are searched depth-first in iteration
    order. Empty containers and ``None`` leaves are skipped.

    Returns:
        The leaf converted with ``str()``, or ``None`` when nothing is found.
    """
    if isinstance(node, dict):
        children = node.values()
    elif isinstance(node, (list, tuple)):
        children = node
    else:
        return None if node is None else str(node)

    for child in children:
        found = _first_error_leaf(child)
        if found is not None:
            return found
    return None


def _extract_error_message(response):
    """Pick a single human-readable message out of ``response.data``.

    Args:
        response: Object whose ``data`` attribute holds the error payload.

    Returns:
        ``str(data["detail"])`` when ``data`` is a dict with a ``"detail"``
        key. Otherwise, for dict or list payloads, the first message found
        by descending through nested dicts and lists; empty entries such as
        ``{}`` or ``[]`` are skipped. ``"Request failed."`` when ``data`` is
        neither a dict nor a list, or contains no message.
    """
    fallback = "Request failed."
    data = response.data

    if isinstance(data, dict) and "detail" in data:
        return str(data["detail"])

    # Payloads may be a dict of field errors, a plain list of messages, or
    # a list of per-item dicts; all three are searched the same way.
    if isinstance(data, (dict, list)):
        found = _first_error_leaf(data)
        if found is not None:
            return found

    return fallback
