# test_cwe89_opensearch_injection.py
  """
  PoC test: CWE-89 OpenSearch injection via unsanitized entity names in query construction.

  The test validates that:

  1. Wildcard special characters (*, ?) in user input to search_labels are escaped
     before being used in OpenSearch wildcard queries, preventing DoS via expensive
     wildcard patterns.
  2. PPL escape handles control characters and additional metacharacters beyond
     just backslash and single-quote.

  Run with: pytest tests/kg/opensearch_impl/test_cwe89_opensearch_injection.py -v
  """
  import re
  from contextlib import asynccontextmanager
  from unittest.mock import AsyncMock, patch

  import numpy as np
  import pytest

  # Must run before the lightrag import below: skips this whole module when the
  # optional ``opensearchpy`` dependency is not installed.
  pytest.importorskip(
      "opensearchpy",
      reason="opensearchpy is required for OpenSearch storage tests",
  )

  from lightrag.kg.opensearch_impl import (  # noqa: E402
      OpenSearchGraphStorage,
      ClientManager,
  )

  # Apply the ``offline`` marker to every test in this module.
  pytestmark = pytest.mark.offline
  @asynccontextmanager
  async def _mock_lock():
      """No-op async context manager; yields ``None`` and holds no real lock."""
      yield
  def _mock_lock_factory():
      """Return a new ``_mock_lock()`` context manager on every call."""
      return _mock_lock()
  @pytest.fixture(autouse=True)
  def patch_data_init_lock():
      """Patch ``get_data_init_lock`` in ``opensearch_impl`` for every test.

      The patched callable delegates to ``_mock_lock_factory`` (via
      ``side_effect``), so each call hands back a fresh no-op async context
      manager. The patch stays active until the test finishes.
      """
      with patch(
          "lightrag.kg.opensearch_impl.get_data_init_lock",
          side_effect=_mock_lock_factory,
      ):
          yield
  class MockEmbeddingFunc:
      """Stand-in embedding callable that returns random float32 vectors."""

      def __init__(self, dim=128):
          """Store the vector width ``dim`` plus fixed token-size/model-name values."""
          self.embedding_dim = dim
          self.max_token_size = 512
          self.model_name = "mock-embed"

      async def __call__(self, texts, **kwargs):
          """Return a random ``(len(texts), embedding_dim)`` float32 array.

          Extra keyword arguments are accepted and ignored.
          """
          return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
  @pytest.fixture
  def global_config():
      """Provide the config dict handed to the storage under test."""
      return {
          "embedding_batch_num": 10,
          "max_graph_nodes": 1000,
      }
  @pytest.fixture
  def embed_func():
      """Provide a ``MockEmbeddingFunc`` built with its default dimension."""
      return MockEmbeddingFunc()
  def _make_client():
      """Build an ``AsyncOpenSearch``-spec'd mock whose used methods are awaitable."""
      from opensearchpy import AsyncOpenSearch

      client = AsyncMock(spec=AsyncOpenSearch)
      # opensearchpy decorates client methods with @query_params, which hides their
      # coroutine nature from inspect.iscoroutinefunction. AsyncMock(spec=...) then
      # creates plain MagicMocks for them, so callers awaiting client.search(...) hit
      # "object dict can't be used in 'await' expression". Force AsyncMock explicitly.
      client.search = AsyncMock()
      client.indices = AsyncMock()
      client.indices.exists = AsyncMock(return_value=True)
      client.indices.refresh = AsyncMock()
      client.transport = AsyncMock()
      return client
  @pytest.fixture
  def graph_storage(global_config, embed_func):
      """Yield an ``OpenSearchGraphStorage`` wired to a mocked client.

      ``ClientManager.get_client`` is patched to return the mock, and the
      storage is flagged as already initialised so tests can call query
      methods directly.
      """
      # NOTE(review): the listing lost its indentation; the whole body, including
      # the ``yield``, is kept inside the ``with`` so the patch covers the test.
      with patch.object(ClientManager, "get_client") as mock_get:
          client = _make_client()
          mock_get.return_value = client
          storage = OpenSearchGraphStorage(
              namespace="test_graph",
              global_config=global_config,
              embedding_func=embed_func,
          )
          storage.client = client
          storage._indices_ready = True
          storage._ppl_graphlookup_available = True
          yield storage
  class TestWildcardInjection:
      """Test that wildcard special chars are escaped in search_labels."""

      @staticmethod
      def _sent_wildcard_value(client):
          """Return the ``entity_id`` wildcard pattern from the last ``client.search`` call.

          Fails with a readable assertion message (rather than a ``TypeError`` or
          ``AttributeError``) when no search was issued, no ``body`` keyword was
          passed, or the ``should`` list has no wildcard clause.
          """
          assert client.search.called, "search should have been called"
          body = client.search.call_args.kwargs.get("body")
          assert body is not None, "search should have been called with a body"
          pattern = next(
              (
                  clause["wildcard"]["entity_id"]["value"]
                  for clause in body["query"]["bool"]["should"]
                  if "wildcard" in clause
              ),
              None,
          )
          assert pattern is not None, "wildcard clause should exist"
          return pattern

      @pytest.mark.asyncio
      async def test_wildcard_chars_escaped_in_search_labels(self, graph_storage):
          """Wildcard metacharacters *, ? in user input must be escaped."""
          client = graph_storage.client
          # Setup mock to return empty results
          client.search.return_value = {"hits": {"hits": []}}

          # Malicious query with wildcard chars that could cause expensive patterns
          malicious_query = "test*?foo"
          await graph_storage.search_labels(malicious_query)

          # Inspect the wildcard pattern that was sent to OpenSearch
          wildcard_clause = self._sent_wildcard_value(client)

          # The outer * wrapping is fine (those are the intentional wildcards),
          # but the inner user-provided * and ? must be escaped.
          # Expected: *test\*\?foo* (with the user's * and ? escaped)
          inner_value = wildcard_clause[1:-1]  # strip leading and trailing *
          assert (
              "\\*" in inner_value
          ), f"User's '*' should be escaped as '\\*' in wildcard, got: {wildcard_clause}"
          assert (
              "\\?" in inner_value
          ), f"User's '?' should be escaped as '\\?' in wildcard, got: {wildcard_clause}"

      @pytest.mark.asyncio
      async def test_wildcard_heavy_pattern_not_exploitable(self, graph_storage):
          """A series of ? chars should be escaped, not passed raw to OpenSearch."""
          client = graph_storage.client
          client.search.return_value = {"hits": {"hits": []}}

          # Attack: a long run of single-char wildcards makes matching expensive
          attack_query = "?" * 50
          await graph_storage.search_labels(attack_query)

          wildcard_clause = self._sent_wildcard_value(client)

          # The value between the outer * delimiters should have all ? escaped
          inner = wildcard_clause[1:-1]
          # Collect unescaped ? (i.e., ? not preceded by \)
          unescaped_q = re.findall(r"(?<!\\)\?", inner)
          assert (
              len(unescaped_q) == 0
          ), f"Found {len(unescaped_q)} unescaped '?' in wildcard pattern: {wildcard_clause}"

      @pytest.mark.asyncio
      async def test_backslash_escaped_in_wildcard(self, graph_storage):
          """Backslashes in user input must be double-escaped for the wildcard query."""
          client = graph_storage.client
          client.search.return_value = {"hits": {"hits": []}}

          attack_query = "test\\*"  # the five characters test, backslash, star
          await graph_storage.search_labels(attack_query)

          wildcard_clause = self._sent_wildcard_value(client)

          # The user's \ must become \\ and the user's * must become \*, so the
          # pattern has to contain the three-backslash sequence \\\* .
          # Checking for "\\" OR "\*" separately is not enough: the raw,
          # completely unescaped input already contains "\*".
          inner = wildcard_clause[1:-1]
          assert (
              "\\\\\\*" in inner
          ), f"Backslash and * from user should be escaped in wildcard: {wildcard_clause}"
  class TestPPLInjection:
      """Test that PPL string escape handles additional metacharacters."""

      def test_escape_ppl_basic_quote(self, graph_storage):
          """Single quotes should be escaped."""
          result = graph_storage._escape_ppl("it's a test")
          # Remove every backslash-escaped quote; any quote still present is raw.
          assert "'" not in result.replace(
              "\\'", ""
          ), f"Unescaped quote found in: {result}"

      def test_escape_ppl_backslash(self, graph_storage):
          """Backslashes should be escaped."""
          result = graph_storage._escape_ppl("test\\path")
          assert result == "test\\\\path"

      def test_escape_ppl_newline_and_control_chars(self, graph_storage):
          """Newlines and carriage returns must not survive raw in the output."""
          result = graph_storage._escape_ppl("line1\nline2\rline3\t")
          # Control chars should either be stripped or escaped — no raw newlines.
          # The trailing tab is part of the input but is not asserted on here.
          assert "\n" not in result, f"Raw newline in PPL literal: {repr(result)}"
          assert "\r" not in result, f"Raw carriage return in PPL literal: {repr(result)}"

      def test_escape_ppl_pipe_in_quotes_safe(self, graph_storage):
          """Pipe character inside a quoted string literal poses no injection risk."""
          result = graph_storage._escape_ppl("entity | stats count()")
          # Smoke check only: the call must succeed and return a string.
          assert isinstance(result, str)
  if __name__ == "__main__":
      # pytest.main returns the run's exit status; raise it so that running this
      # file as a script exits non-zero when tests fail.
      raise SystemExit(pytest.main([__file__, "-v"]))