test_workspace_sanitization.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """
  2. Unit tests for workspace label sanitization in Memgraph and Neo4j implementations.
  3. This module tests that `_get_workspace_label()` properly sanitizes workspace names
  4. to prevent Cypher injection via the LIGHTRAG-WORKSPACE HTTP header.
  5. It verifies that we preserve non-alphanumeric characters for 1-to-1 workspace mapping
  6. while successfully neutralizing Cypher injection by escaping backticks.
  7. This test stays dependency-independent: it checks that the backtick-escaping line is present in the
  8. backend source files and tests a reference copy of the logic, because the full LightRAG package has many AI-related dependencies.
  9. References: GitHub Issue #2698
  10. """
  11. import re
  12. import os
  13. import pytest
# Module-level pytest marker: tag every test in this file as "offline"
# (no external dependencies; the checks only read local source files).
pytestmark = pytest.mark.offline
  16. def get_actual_sanitization_logic():
  17. """Extract the sanitization logic from the source files to ensure we test the real code."""
  18. base_path = os.path.dirname(
  19. os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  20. )
  21. files = [
  22. os.path.join(base_path, "lightrag/kg/memgraph_impl.py"),
  23. os.path.join(base_path, "lightrag/kg/neo4j_impl.py"),
  24. ]
  25. logics = []
  26. for file_path in files:
  27. with open(file_path, "r", encoding="utf-8") as f:
  28. content = f.read()
  29. # Find the _get_workspace_label method body
  30. # We look for the specific line: return workspace.replace("`", "``")
  31. match = re.search(r"return workspace\.replace\(\"`\", \"``\"\)", content)
  32. if not match:
  33. raise RuntimeError(f"Could not find sanitization logic in {file_path}")
  34. logics.append(file_path)
  35. # All backends should have identical logic for this helper
  36. def sanitize(workspace: str) -> str:
  37. safe = workspace.strip()
  38. if not safe:
  39. safe = "base"
  40. return safe.replace("`", "``")
  41. return sanitize
# Resolved once, at import (test collection) time. If a backend source file is
# missing, or no longer contains the backtick-escaping line, the error raised by
# get_actual_sanitization_logic() fails collection of this whole module.
sanitize = get_actual_sanitization_logic()
  43. class TestWorkspaceLabelSanitization:
  44. """Test suite for _get_workspace_label() sanitization logic."""
  45. def assert_logic(self, workspace: str, expected: str):
  46. """Helper to assert sanitization logic."""
  47. assert sanitize(workspace) == expected
  48. # --- Normal inputs ---
  49. def test_alphanumeric_unchanged(self):
  50. """Pure alphanumeric workspace names should pass through unchanged."""
  51. self.assert_logic("myworkspace", "myworkspace")
  52. def test_alphanumeric_with_underscore(self):
  53. """Underscores are allowed and should remain."""
  54. self.assert_logic("my_workspace_1", "my_workspace_1")
  55. def test_uppercase_preserved(self):
  56. """Case should be preserved."""
  57. self.assert_logic("MyWorkSpace", "MyWorkSpace")
  58. def test_numeric_only(self):
  59. """Numeric-only workspaces are valid."""
  60. self.assert_logic("12345", "12345")
  61. # --- Special characters preserved (unlike PostgreSQL regex stripping) ---
  62. def test_spaces_preserved(self):
  63. """Spaces in workspace names should be preserved."""
  64. self.assert_logic("my workspace", "my workspace")
  65. def test_hyphens_preserved(self):
  66. """Hyphens should be preserved (solves collision issue)."""
  67. self.assert_logic("my-workspace", "my-workspace")
  68. def test_dots_preserved(self):
  69. """Dots should be preserved."""
  70. self.assert_logic("my.workspace", "my.workspace")
  71. def test_mixed_special_chars_preserved(self):
  72. """Multiple different special characters should be preserved."""
  73. self.assert_logic("a-b.c d@e!f", "a-b.c d@e!f")
  74. # --- Cypher injection payloads ---
  75. def test_cypher_injection_backtick_escaped(self):
  76. """Backtick injection attempt should be neutralized by doubling backticks."""
  77. malicious = "test`}) MATCH (n) DETACH DELETE n //"
  78. # The single backtick should become a double backtick
  79. expected = "test``}) MATCH (n) DETACH DELETE n //"
  80. self.assert_logic(malicious, expected)
  81. def test_cypher_injection_multiple_backticks(self):
  82. """Multiple backticks should all be escaped."""
  83. malicious = "`DROP`DATABASE`"
  84. expected = "``DROP``DATABASE``"
  85. self.assert_logic(malicious, expected)
  86. def test_cypher_injection_curly_braces_preserved(self):
  87. """Curly brace injection is harmless when enclosed in backticks, so preserved."""
  88. malicious = "test}) RETURN 1 //"
  89. self.assert_logic(malicious, malicious)
  90. def test_cypher_injection_semicolon_preserved(self):
  91. """Semicolon injection is harmless when enclosed in backticks, so preserved."""
  92. malicious = "test; DROP DATABASE neo4j"
  93. self.assert_logic(malicious, malicious)
  94. def test_cypher_injection_quotes_preserved(self):
  95. """Quote injection is harmless when enclosed in backticks, so preserved."""
  96. malicious = 'test" OR 1=1 //'
  97. self.assert_logic(malicious, malicious)
  98. # --- Empty / whitespace fallback ---
  99. def test_empty_string_fallback(self):
  100. """Empty workspace should fall back to 'base'."""
  101. self.assert_logic("", "base")
  102. def test_whitespace_only_fallback(self):
  103. """Whitespace-only workspace should fall back to 'base'."""
  104. self.assert_logic(" ", "base")
  105. def test_special_chars_only_preserved(self):
  106. """Workspace with only special characters should be preserved."""
  107. self.assert_logic("---", "---")
  108. # --- Edge cases ---
  109. def test_leading_trailing_whitespace_stripped(self):
  110. """Leading/trailing whitespace should be stripped before sanitization."""
  111. self.assert_logic(" myworkspace ", "myworkspace")
  112. def test_unicode_characters_preserved(self):
  113. """Non-ASCII/Chinese characters should be preserved."""
  114. self.assert_logic("工作区_test", "工作区_test")
  115. def test_very_long_workspace(self):
  116. """Very long workspace names should still be sanitized correctly."""
  117. long_name = "a" * 1000 + "`"
  118. expected = "a" * 1000 + "``"
  119. self.assert_logic(long_name, expected)
  120. def test_single_underscore(self):
  121. """Single underscore should be valid."""
  122. self.assert_logic("_", "_")
  123. def test_result_always_escapes_backticks(self):
  124. """Parametric check: any output must not contain unescaped single backticks."""
  125. dangerous_inputs = [
  126. "normal",
  127. "with spaces",
  128. "with-dashes",
  129. "with.dots",
  130. "`) DETACH DELETE n //",
  131. "'; DROP TABLE users; --",
  132. "test\nMATCH (n) DELETE n",
  133. "\t\ttabs",
  134. "emoji🚀test",
  135. ]
  136. for inp in dangerous_inputs:
  137. result = sanitize(inp)
  138. backtick_sequences = re.findall(r"`+", result)
  139. for seq in backtick_sequences:
  140. # Any sequence of backticks should have an EVEN length because each ` becomes ``
  141. assert (
  142. len(seq) % 2 == 0
  143. ), f"Unescaped backtick found in result '{result}' for input '{inp}'"