test_document_routes_paginated.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import importlib
  2. import sys
  3. from types import SimpleNamespace
  4. import pytest
  5. from fastapi import FastAPI
  6. from fastapi.testclient import TestClient
  # Import the modules under test while ``sys.argv`` holds only the program name.
  # NOTE(review): presumably the LightRAG API modules inspect the command line
  # at import time, so pytest's own arguments have to be hidden — confirm.
  _original_argv = sys.argv[:]
  sys.argv = sys.argv[:1]  # keep just the program name; also safe if argv is empty
  try:
      _document_routes = importlib.import_module("lightrag.api.routers.document_routes")
      _base = importlib.import_module("lightrag.base")
  finally:
      # Restore the real arguments even when an import fails, so a broken
      # import cannot leak a truncated ``sys.argv`` into the rest of the run.
      sys.argv = _original_argv

  # Module-level aliases for the names used by the helpers and tests below.
  create_document_routes = _document_routes.create_document_routes
  DocProcessingStatus = _base.DocProcessingStatus
  DocStatus = _base.DocStatus
  DocStatusStorage = _base.DocStatusStorage

  # Apply the ``offline`` marker to every test in this module.
  pytestmark = pytest.mark.offline
  def _doc(status: DocStatus, suffix: str) -> DocProcessingStatus:
      """Build a fixture ``DocProcessingStatus`` in the given *status*.

      The summary text is derived from ``status.value`` and the file path from
      *suffix*; the content length, timestamps and metadata are fixed values.
      """
      # Creation and update times share the same fixed instant.
      fixed_timestamp = "2024-01-01T00:00:00+00:00"
      summary = f"{status.value} summary"
      path = f"{suffix}.pdf"
      return DocProcessingStatus(
          status=status,
          file_path=path,
          content_summary=summary,
          content_length=10,
          metadata={},
          created_at=fixed_timestamp,
          updated_at=fixed_timestamp,
      )
  class _FakeDocStatusStorage:
      """In-memory stand-in for the document-status storage used by the routes.

      Holds three fixed documents keyed by document id, one each in the
      PROCESSED, PARSING and ANALYZING states.
      """

      def __init__(self):
          # Insertion order matters: results are returned in this order.
          self.docs = {
              "processed-doc": _doc(DocStatus.PROCESSED, "processed"),
              "parsing-doc": _doc(DocStatus.PARSING, "parsing"),
              "analyzing-doc": _doc(DocStatus.ANALYZING, "analyzing"),
          }

      async def get_docs_paginated(
          self,
          status_filter=None,
          status_filters=None,
          page=1,
          page_size=50,
          sort_field="updated_at",
          sort_direction="desc",
      ):
          """Return one page of ``(doc_id, doc)`` pairs plus the total match count.

          Args:
              status_filter: Single status filter, forwarded to
                  ``DocStatusStorage.resolve_status_filter_values``.
              status_filters: Multiple status filters, forwarded to the same
                  resolver.
              page: 1-based page number; values below 1 are treated as 1.
              page_size: Maximum number of documents in the returned page.
              sort_field: Accepted for signature compatibility; not used here.
              sort_direction: Accepted for signature compatibility; not used here.

          Returns:
              A tuple ``(page_items, total_count)`` where ``page_items`` is the
              requested slice of matching documents in insertion order and
              ``total_count`` is the number of matches across all pages.
          """
          selected_statuses = DocStatusStorage.resolve_status_filter_values(
              status_filter=status_filter,
              status_filters=status_filters,
          )
          # A resolver result of ``None`` means "no filtering".
          documents = [
              (doc_id, doc)
              for doc_id, doc in self.docs.items()
              if selected_statuses is None or doc.status.value in selected_statuses
          ]
          # Honour ``page`` so that pages beyond the first do not repeat the
          # first page; page 1 yields the same slice as ``documents[:page_size]``.
          offset = (max(page, 1) - 1) * page_size
          return documents[offset : offset + page_size], len(documents)

      async def get_all_status_counts(self):
          """Return the fixed per-status document counts for the fixture data."""
          return {"processed": 1, "parsing": 1, "analyzing": 1}
  # Shared fixtures: one fake storage, one app with the document router mounted,
  # and one test client reused by every test in this module.
  _fake_doc_status = _FakeDocStatusStorage()
  _app = FastAPI()

  # The first argument only needs a ``doc_status`` attribute for these tests;
  # the second is an empty placeholder object.
  _router = create_document_routes(
      SimpleNamespace(doc_status=_fake_doc_status),
      SimpleNamespace(),
      api_key="test-key",
  )
  _app.include_router(_router)

  _client = TestClient(_app)
  # Requests send the same key the router was created with.
  _headers = {"X-API-Key": "test-key"}
  def test_documents_paginated_accepts_status_filter():
      """A single ``status_filter`` restricts the listing to matching documents."""
      request_body = {
          "status_filter": "processed",
          "page": 1,
          "page_size": 10,
          "sort_field": "updated_at",
          "sort_direction": "desc",
      }

      response = _client.post(
          "/documents/paginated", headers=_headers, json=request_body
      )

      assert response.status_code == 200
      payload = response.json()
      assert payload["pagination"]["total_count"] == 1
      # Only the one PROCESSED fixture document should come back.
      returned_ids = [doc["id"] for doc in payload["documents"]]
      assert returned_ids == ["processed-doc"]
  def test_documents_paginated_status_filters_override_status_filter():
      """When both filters are sent, ``status_filters`` wins over ``status_filter``."""
      request_body = {
          "status_filter": "processed",
          "status_filters": ["parsing", "analyzing"],
          "page": 1,
          "page_size": 10,
          "sort_field": "updated_at",
          "sort_direction": "desc",
      }

      response = _client.post(
          "/documents/paginated", headers=_headers, json=request_body
      )

      assert response.status_code == 200
      payload = response.json()
      assert payload["pagination"]["total_count"] == 2
      # The PROCESSED document must be absent even though ``status_filter``
      # named it; the two matches arrive in the fake storage's insertion order.
      returned_ids = [doc["id"] for doc in payload["documents"]]
      expected_ids = ["parsing-doc", "analyzing-doc"]
      assert returned_ids == expected_ids