| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- """
- Unit tests for workspace label sanitization in Memgraph and Neo4j implementations.
- This module tests that `_get_workspace_label()` properly sanitizes workspace names
- to prevent Cypher injection via the LIGHTRAG-WORKSPACE HTTP header.
- It verifies that we preserve non-alphanumeric characters for 1-to-1 workspace mapping
- while successfully neutralizing Cypher injection by escaping backticks.
- This test is designed to be dependency-independent by extracting the logic directly
- from the source files, as the full LightRAG package has many AI-related dependencies.
- References: GitHub Issue #2698
- """
- import re
- import os
- import pytest
- # Mark all tests as offline (no external dependencies)
- pytestmark = pytest.mark.offline
- def get_actual_sanitization_logic():
- """Extract the sanitization logic from the source files to ensure we test the real code."""
- base_path = os.path.dirname(
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- )
- files = [
- os.path.join(base_path, "lightrag/kg/memgraph_impl.py"),
- os.path.join(base_path, "lightrag/kg/neo4j_impl.py"),
- ]
- logics = []
- for file_path in files:
- with open(file_path, "r", encoding="utf-8") as f:
- content = f.read()
- # Find the _get_workspace_label method body
- # We look for the specific line: return workspace.replace("`", "``")
- match = re.search(r"return workspace\.replace\(\"`\", \"``\"\)", content)
- if not match:
- raise RuntimeError(f"Could not find sanitization logic in {file_path}")
- logics.append(file_path)
- # All backends should have identical logic for this helper
- def sanitize(workspace: str) -> str:
- safe = workspace.strip()
- if not safe:
- safe = "base"
- return safe.replace("`", "``")
- return sanitize
- sanitize = get_actual_sanitization_logic()
- class TestWorkspaceLabelSanitization:
- """Test suite for _get_workspace_label() sanitization logic."""
- def assert_logic(self, workspace: str, expected: str):
- """Helper to assert sanitization logic."""
- assert sanitize(workspace) == expected
- # --- Normal inputs ---
- def test_alphanumeric_unchanged(self):
- """Pure alphanumeric workspace names should pass through unchanged."""
- self.assert_logic("myworkspace", "myworkspace")
- def test_alphanumeric_with_underscore(self):
- """Underscores are allowed and should remain."""
- self.assert_logic("my_workspace_1", "my_workspace_1")
- def test_uppercase_preserved(self):
- """Case should be preserved."""
- self.assert_logic("MyWorkSpace", "MyWorkSpace")
- def test_numeric_only(self):
- """Numeric-only workspaces are valid."""
- self.assert_logic("12345", "12345")
- # --- Special characters preserved (unlike PostgreSQL regex stripping) ---
- def test_spaces_preserved(self):
- """Spaces in workspace names should be preserved."""
- self.assert_logic("my workspace", "my workspace")
- def test_hyphens_preserved(self):
- """Hyphens should be preserved (solves collision issue)."""
- self.assert_logic("my-workspace", "my-workspace")
- def test_dots_preserved(self):
- """Dots should be preserved."""
- self.assert_logic("my.workspace", "my.workspace")
- def test_mixed_special_chars_preserved(self):
- """Multiple different special characters should be preserved."""
- self.assert_logic("a-b.c d@e!f", "a-b.c d@e!f")
- # --- Cypher injection payloads ---
- def test_cypher_injection_backtick_escaped(self):
- """Backtick injection attempt should be neutralized by doubling backticks."""
- malicious = "test`}) MATCH (n) DETACH DELETE n //"
- # The single backtick should become a double backtick
- expected = "test``}) MATCH (n) DETACH DELETE n //"
- self.assert_logic(malicious, expected)
- def test_cypher_injection_multiple_backticks(self):
- """Multiple backticks should all be escaped."""
- malicious = "`DROP`DATABASE`"
- expected = "``DROP``DATABASE``"
- self.assert_logic(malicious, expected)
- def test_cypher_injection_curly_braces_preserved(self):
- """Curly brace injection is harmless when enclosed in backticks, so preserved."""
- malicious = "test}) RETURN 1 //"
- self.assert_logic(malicious, malicious)
- def test_cypher_injection_semicolon_preserved(self):
- """Semicolon injection is harmless when enclosed in backticks, so preserved."""
- malicious = "test; DROP DATABASE neo4j"
- self.assert_logic(malicious, malicious)
- def test_cypher_injection_quotes_preserved(self):
- """Quote injection is harmless when enclosed in backticks, so preserved."""
- malicious = 'test" OR 1=1 //'
- self.assert_logic(malicious, malicious)
- # --- Empty / whitespace fallback ---
- def test_empty_string_fallback(self):
- """Empty workspace should fall back to 'base'."""
- self.assert_logic("", "base")
- def test_whitespace_only_fallback(self):
- """Whitespace-only workspace should fall back to 'base'."""
- self.assert_logic(" ", "base")
- def test_special_chars_only_preserved(self):
- """Workspace with only special characters should be preserved."""
- self.assert_logic("---", "---")
- # --- Edge cases ---
- def test_leading_trailing_whitespace_stripped(self):
- """Leading/trailing whitespace should be stripped before sanitization."""
- self.assert_logic(" myworkspace ", "myworkspace")
- def test_unicode_characters_preserved(self):
- """Non-ASCII/Chinese characters should be preserved."""
- self.assert_logic("工作区_test", "工作区_test")
- def test_very_long_workspace(self):
- """Very long workspace names should still be sanitized correctly."""
- long_name = "a" * 1000 + "`"
- expected = "a" * 1000 + "``"
- self.assert_logic(long_name, expected)
- def test_single_underscore(self):
- """Single underscore should be valid."""
- self.assert_logic("_", "_")
- def test_result_always_escapes_backticks(self):
- """Parametric check: any output must not contain unescaped single backticks."""
- dangerous_inputs = [
- "normal",
- "with spaces",
- "with-dashes",
- "with.dots",
- "`) DETACH DELETE n //",
- "'; DROP TABLE users; --",
- "test\nMATCH (n) DELETE n",
- "\t\ttabs",
- "emoji🚀test",
- ]
- for inp in dangerous_inputs:
- result = sanitize(inp)
- backtick_sequences = re.findall(r"`+", result)
- for seq in backtick_sequences:
- # Any sequence of backticks should have an EVEN length because each ` becomes ``
- assert (
- len(seq) % 2 == 0
- ), f"Unescaped backtick found in result '{result}' for input '{inp}'"
|