test_openai_retry_transient.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """Regression tests for retrying transient OpenAI failures.
  2. Covers:
  3. * HTTP 5xx (InternalServerError) is retried on both complete and embed.
  4. * Transient "could not parse JSON body" 400s are converted to a retryable
  5. TransientBadRequestError, while genuine 400s fail fast.
  6. """
  7. from types import SimpleNamespace
  8. from unittest.mock import AsyncMock, patch
  9. import httpx
  10. import pytest
  11. from openai import BadRequestError, InternalServerError
  12. from lightrag.llm.openai import (
  13. TransientBadRequestError,
  14. openai_complete_if_cache,
  15. openai_embed,
  16. )
  17. def _retry_exception_types(func) -> set[type]:
  18. """Collect the exception types a tenacity-decorated func retries on."""
  19. types: set[type] = set()
  20. def _walk(retry_obj):
  21. # retry_any / retry_all expose `.retries`; retry_if_exception_type
  22. # exposes `.exception_types` (a single type or a tuple of types).
  23. for child in getattr(retry_obj, "retries", ()):
  24. _walk(child)
  25. exc_types = getattr(retry_obj, "exception_types", ())
  26. if isinstance(exc_types, type):
  27. exc_types = (exc_types,)
  28. types.update(exc_types)
  29. # openai_embed is wrapped by @wrap_embedding_func_with_attrs; the
  30. # tenacity-decorated callable is on `.func`.
  31. target = getattr(func, "func", func)
  32. _walk(target.retry.retry)
  33. return types
  34. def _make_bad_request(message: str) -> BadRequestError:
  35. request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
  36. response = httpx.Response(status_code=400, request=request)
  37. return BadRequestError(message, response=response, body=None)
  38. def _make_error_client(error: Exception) -> SimpleNamespace:
  39. return SimpleNamespace(
  40. chat=SimpleNamespace(
  41. completions=SimpleNamespace(create=AsyncMock(side_effect=error))
  42. ),
  43. close=AsyncMock(),
  44. )
  45. @pytest.mark.offline
  46. def test_complete_retries_5xx_and_transient_400():
  47. retried = _retry_exception_types(openai_complete_if_cache)
  48. assert InternalServerError in retried
  49. assert TransientBadRequestError in retried
  50. @pytest.mark.offline
  51. def test_embed_retries_5xx():
  52. assert InternalServerError in _retry_exception_types(openai_embed)
  53. @pytest.mark.offline
  54. @pytest.mark.asyncio
  55. async def test_transient_json_parse_400_is_wrapped():
  56. """A 'could not parse JSON body' 400 becomes a retryable wrapper."""
  57. err = _make_bad_request(
  58. "Error code: 400 - We could not parse the JSON body of your request."
  59. )
  60. fake_client = _make_error_client(err)
  61. # Call the undecorated coroutine to exercise the handler exactly once
  62. # (bypasses the tenacity retry loop and its waits).
  63. with patch(
  64. "lightrag.llm.openai.create_openai_async_client", return_value=fake_client
  65. ):
  66. with pytest.raises(TransientBadRequestError):
  67. await openai_complete_if_cache.__wrapped__(
  68. model="gpt-4o-mini", prompt="hello"
  69. )
  70. fake_client.close.assert_awaited()
  71. @pytest.mark.offline
  72. @pytest.mark.asyncio
  73. async def test_genuine_400_fails_fast():
  74. """A non-parse 400 (e.g. bad params) is not wrapped, propagates, and closes the client."""
  75. err = _make_bad_request("Error code: 400 - Invalid value for 'temperature'.")
  76. fake_client = _make_error_client(err)
  77. with patch(
  78. "lightrag.llm.openai.create_openai_async_client", return_value=fake_client
  79. ):
  80. with pytest.raises(BadRequestError):
  81. await openai_complete_if_cache.__wrapped__(
  82. model="gpt-4o-mini", prompt="hello"
  83. )
  84. # The non-transient 400 path must still close the underlying httpx client
  85. # to avoid connection leaks in validation-heavy/misconfigured runs.
  86. fake_client.close.assert_awaited()