| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- from types import SimpleNamespace
- from unittest.mock import AsyncMock, patch
- import pytest
- from lightrag.llm.openai import openai_complete_if_cache
def _make_completion(content: str, finish_reason: str = "stop") -> SimpleNamespace:
    """Build a stand-in for a non-streaming chat-completion response.

    Args:
        content: Text stored in ``choices[0].message.content``.
        finish_reason: Value stored in ``choices[0].finish_reason``.

    Returns:
        A ``SimpleNamespace`` holding exactly one choice plus a fixed
        ``usage`` record (10 prompt + 20 completion = 30 total tokens).
    """
    message = SimpleNamespace(
        content=content,
        parsed=None,
        reasoning_content="",
    )
    choice = SimpleNamespace(finish_reason=finish_reason, message=message)
    usage = SimpleNamespace(
        prompt_tokens=10,
        completion_tokens=20,
        total_tokens=30,
    )
    return SimpleNamespace(choices=[choice], usage=usage)
def _make_fake_client(completion) -> SimpleNamespace:
    """Build a stand-in for the async client used by the tests.

    Args:
        completion: Object handed back when ``chat.completions.create`` is
            awaited (a completion namespace or a fake stream).

    Returns:
        A ``SimpleNamespace`` exposing ``chat.completions.create`` and
        ``close`` as ``AsyncMock`` objects, so tests can assert on how each
        one was awaited.
    """
    create = AsyncMock(return_value=completion)
    return SimpleNamespace(
        chat=SimpleNamespace(completions=SimpleNamespace(create=create)),
        close=AsyncMock(),
    )
class _FakeAsyncStream:
    """Minimal async iterator that replays a fixed sequence of chunks."""

    def __init__(self, chunks):
        # Keep a one-shot iterator so each chunk is yielded exactly once.
        self._chunks = iter(chunks)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self._chunks)
        except StopIteration:
            # Translate sync exhaustion into the async-iteration signal;
            # ``from None`` keeps the StopIteration out of the traceback.
            raise StopAsyncIteration from None

    async def aclose(self):
        """Do nothing; present so callers may await ``aclose()``."""
        return None
def _make_stream_chunk(content=None, reasoning_content=None) -> SimpleNamespace:
    """Build a stand-in for one streamed completion chunk.

    Args:
        content: Value stored in ``choices[0].delta.content``.
        reasoning_content: Value stored in
            ``choices[0].delta.reasoning_content``.

    Returns:
        A ``SimpleNamespace`` with a single choice wrapping the delta.
    """
    delta = SimpleNamespace(content=content, reasoning_content=reasoning_content)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])
@pytest.mark.offline
@pytest.mark.asyncio
async def test_length_finish_reason_returns_raw_content():
    """Truncated responses (finish_reason='length') still yield raw content.

    After the dispatch simplification, we no longer rely on the typed
    ``LengthFinishReasonError`` path — ``create()`` returns the partial
    content unchanged and upstream tolerant JSON parsing handles it.
    """
    raw_json = (
        '{"entities":[{"name":"Alice","type":"Person",'
        '"description":"Founder"}],"relationships":[]}'
    )
    fake_client = _make_fake_client(
        _make_completion(raw_json, finish_reason="length")
    )

    with patch(
        "lightrag.llm.openai.create_openai_async_client",
        return_value=fake_client,
    ):
        result = await openai_complete_if_cache(
            model="test-model",
            prompt="Extract entities",
            response_format={"type": "json_object"},
            max_completion_tokens=128,
        )

    assert result == raw_json
    fake_client.chat.completions.create.assert_awaited_once()
    fake_client.close.assert_awaited_once()
@pytest.mark.offline
@pytest.mark.asyncio
async def test_json_object_response_format_forwarded_to_create():
    """A ``json_object`` response_format must be passed through to ``create()``."""
    # Single source of truth for the payload the fake client returns.
    payload = '{"high_level_keywords":["AI"],"low_level_keywords":["RAG"]}'
    fake_client = _make_fake_client(_make_completion(payload))

    with patch(
        "lightrag.llm.openai.create_openai_async_client",
        return_value=fake_client,
    ):
        result = await openai_complete_if_cache(
            model="test-model",
            prompt="Extract keywords",
            response_format={"type": "json_object"},
        )

    assert result == payload
    create = fake_client.chat.completions.create
    create.assert_awaited_once()
    # Compare against a fresh literal so in-place mutation would be caught.
    assert create.await_args.kwargs["response_format"] == {"type": "json_object"}
    fake_client.close.assert_awaited_once()
@pytest.mark.offline
@pytest.mark.asyncio
async def test_legacy_entity_extraction_emits_deprecation_warning():
    """``entity_extraction=True`` must warn and still request JSON output."""
    fake_client = _make_fake_client(
        _make_completion('{"entities":[],"relationships":[]}')
    )

    with patch(
        "lightrag.llm.openai.create_openai_async_client",
        return_value=fake_client,
    ):
        with pytest.warns(DeprecationWarning):
            await openai_complete_if_cache(
                model="test-model",
                prompt="Extract entities",
                entity_extraction=True,
            )

    create = fake_client.chat.completions.create
    create.assert_awaited_once()
    # The legacy flag is expected to be translated into a json_object request.
    assert create.await_args.kwargs["response_format"] == {"type": "json_object"}
@pytest.mark.offline
@pytest.mark.asyncio
async def test_legacy_keyword_extraction_emits_deprecation_warning():
    """``keyword_extraction=True`` must warn and still request JSON output."""
    fake_client = _make_fake_client(
        _make_completion('{"high_level_keywords":[],"low_level_keywords":[]}')
    )

    with patch(
        "lightrag.llm.openai.create_openai_async_client",
        return_value=fake_client,
    ):
        with pytest.warns(DeprecationWarning):
            await openai_complete_if_cache(
                model="test-model",
                prompt="Extract keywords",
                keyword_extraction=True,
            )

    create = fake_client.chat.completions.create
    create.assert_awaited_once()
    # The legacy flag is expected to be translated into a json_object request.
    assert create.await_args.kwargs["response_format"] == {"type": "json_object"}
@pytest.mark.offline
@pytest.mark.asyncio
async def test_typed_response_format_is_rejected():
    """Passing a class as ``response_format`` must raise ``TypeError``.

    The test also checks that neither ``create()`` nor ``close()`` was
    awaited on the fake client when the call is rejected.
    """

    class FakeSchemaModel:
        """Placeholder class standing in for a typed schema model."""

    fake_client = _make_fake_client(_make_completion("{}"))

    with patch(
        "lightrag.llm.openai.create_openai_async_client",
        return_value=fake_client,
    ):
        with pytest.raises(TypeError, match="typed/Pydantic"):
            await openai_complete_if_cache(
                model="test-model",
                prompt="Extract entities",
                response_format=FakeSchemaModel,
            )

    fake_client.chat.completions.create.assert_not_awaited()
    fake_client.close.assert_not_awaited()
@pytest.mark.offline
@pytest.mark.asyncio
async def test_streaming_structured_output_disables_cot():
    """Streamed structured output must omit reasoning text.

    Even with ``enable_cot=True``, the joined stream is expected to contain
    only the ``content`` chunk, not the ``reasoning_content`` chunk.
    """
    fake_stream = _FakeAsyncStream(
        [
            _make_stream_chunk(reasoning_content="this should not be included"),
            _make_stream_chunk(content='{"answer":"ok"}'),
        ]
    )
    fake_client = _make_fake_client(fake_stream)

    with patch(
        "lightrag.llm.openai.create_openai_async_client",
        return_value=fake_client,
    ):
        stream = await openai_complete_if_cache(
            model="test-model",
            prompt="Extract entities",
            stream=True,
            enable_cot=True,
            response_format={"type": "json_object"},
        )
        # Drain the stream fully before checking the output and cleanup.
        chunks = [chunk async for chunk in stream]

    assert "".join(chunks) == '{"answer":"ok"}'
    fake_client.close.assert_awaited_once()
|