test_vlm_cache_key.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """Offline tests for the VLM cache-key invariants used by analyze_multimodal.
  2. These tests verify the hash inputs we feed into ``compute_args_hash`` actually
  3. deliver the contract documented in the LLM/VLM vision plan:
  4. - same prompt + same image content => cache HIT (identical args_hash)
  5. - same prompt + different image content => cache MISS (different args_hash)
  6. - same prompt + same image content under a different file path/source_id =>
  7. cache HIT (provenance is for audit only and must not affect the hash)
  8. - the audit blob written into ``original_prompt`` never embeds the raw base64
  9. payload, only digests and provenance pointers
  10. """
  11. from __future__ import annotations
  12. import base64
  13. import json
  14. from typing import Any
  15. import pytest
  16. from lightrag.llm._vision_utils import (
  17. image_audit_metadata,
  18. image_cache_metadata,
  19. normalize_image_inputs,
  20. )
  21. from lightrag.utils import (
  22. _serialize_cache_variant,
  23. compute_args_hash,
  24. get_llm_cache_identity,
  25. serialize_llm_cache_identity,
  26. )
# Apply the ``offline`` pytest marker to every test in this module.
pytestmark = pytest.mark.offline

# PNG fixture whose IHDR chunk declares a 1x1 image (the width and height
# fields are both ``\x00\x00\x00\x01``); tests below read those dimensions
# back through the image metadata helpers.
PNG_A = (
    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
    b"\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc\xf8"
    b"\xcf\xc0\x00\x00\x00\x03\x00\x01\x5c\xcc\xd9\x9e\x00\x00\x00\x00"
    b"IEND\xaeB`\x82"
)
# Same bytes as PNG_A except that the byte at index -12 (the first of the four
# zero bytes immediately preceding ``IEND``) is replaced with ``\x01``. The
# IHDR chunk is untouched, so only the overall content differs.
PNG_B = PNG_A[:-12] + b"\x01" + PNG_A[-11:]  # 1-byte tweak => different hash
  35. def _b64(raw: bytes) -> str:
  36. return base64.b64encode(raw).decode("ascii")
  37. def _hash_for(prompt: str, images: list[dict[str, Any]] | None) -> str:
  38. normalized = normalize_image_inputs(images) if images else []
  39. identity = get_llm_cache_identity({}, role="vlm")
  40. return compute_args_hash(
  41. prompt,
  42. "",
  43. "",
  44. serialize_llm_cache_identity(identity),
  45. _serialize_cache_variant({"type": "json_object"}),
  46. _serialize_cache_variant(image_cache_metadata(normalized)),
  47. )
  48. def test_same_prompt_same_image_yields_same_hash():
  49. h1 = _hash_for("describe", [{"base64": _b64(PNG_A)}])
  50. h2 = _hash_for("describe", [{"base64": _b64(PNG_A)}])
  51. assert h1 == h2
  52. def test_same_prompt_different_image_yields_different_hash():
  53. h1 = _hash_for("describe", [{"base64": _b64(PNG_A)}])
  54. h2 = _hash_for("describe", [{"base64": _b64(PNG_B)}])
  55. assert h1 != h2
  56. def test_same_image_different_source_file_still_hits():
  57. h1 = _hash_for(
  58. "describe",
  59. [
  60. {
  61. "base64": _b64(PNG_A),
  62. "source_id": "img-001",
  63. "source_file": "/path/a/img.png",
  64. "modality": "image",
  65. "doc_id": "doc-1",
  66. }
  67. ],
  68. )
  69. h2 = _hash_for(
  70. "describe",
  71. [
  72. {
  73. "base64": _b64(PNG_A),
  74. "source_id": "img-002",
  75. "source_file": "/different/elsewhere/copy.png",
  76. "modality": "image",
  77. "doc_id": "doc-2",
  78. }
  79. ],
  80. )
  81. assert h1 == h2
  82. def test_different_prompt_with_same_image_yields_different_hash():
  83. h1 = _hash_for("describe", [{"base64": _b64(PNG_A)}])
  84. h2 = _hash_for("describe in english", [{"base64": _b64(PNG_A)}])
  85. assert h1 != h2
  86. def test_image_present_vs_absent_yields_different_hash():
  87. h_text_only = _hash_for("describe", None)
  88. h_with_image = _hash_for("describe", [{"base64": _b64(PNG_A)}])
  89. assert h_text_only != h_with_image
  90. def test_audit_block_in_original_prompt_does_not_leak_raw_base64():
  91. """Mirrors how _analyze_item builds the cache-entry original_prompt."""
  92. normalized = normalize_image_inputs(
  93. [
  94. {
  95. "base64": _b64(PNG_A),
  96. "source_id": "img-001",
  97. "source_file": "/tmp/a.png",
  98. "modality": "image",
  99. "doc_id": "doc-1",
  100. }
  101. ]
  102. )
  103. audit_blob = image_audit_metadata(normalized)
  104. prompt = "describe"
  105. original_prompt = (
  106. prompt
  107. + f"\n<vlm_images>{json.dumps(audit_blob, ensure_ascii=False)}</vlm_images>"
  108. )
  109. assert "<vlm_images>" in original_prompt
  110. assert "</vlm_images>" in original_prompt
  111. # sha256 digest is present; raw base64 must not be.
  112. assert audit_blob[0]["sha256"] in original_prompt
  113. assert _b64(PNG_A) not in original_prompt
  114. def test_image_metadata_includes_width_height():
  115. """Design §5.2 contract: image digest metadata must surface
  116. width/height alongside mime/sha256/bytes so cache keys and audit blocks
  117. capture the full pixel footprint."""
  118. normalized = normalize_image_inputs([{"base64": _b64(PNG_A)}])
  119. cache_blob = image_cache_metadata(normalized)
  120. audit_blob = image_audit_metadata(normalized)
  121. assert len(cache_blob) == 1
  122. # 1x1 PNG fixture — dimensions are decodable from the IHDR chunk.
  123. assert cache_blob[0]["width"] == 1
  124. assert cache_blob[0]["height"] == 1
  125. assert audit_blob[0]["width"] == 1
  126. assert audit_blob[0]["height"] == 1
  127. def test_image_dimensions_change_changes_cache_key():
  128. """Two PNGs with the same pixel byte payload but different declared
  129. dimensions still differ at the byte level and therefore must hash to
  130. distinct args_hashes — the width/height fields in cache metadata
  131. document the difference without being the sole identity source."""
  132. # Build a 32x16 PNG and compare it against the 1x1 PNG_A.
  133. import struct
  134. import zlib
  135. sig = b"\x89PNG\r\n\x1a\n"
  136. ihdr_payload = struct.pack(">II", 32, 16) + b"\x08\x06\x00\x00\x00"
  137. ihdr_crc = zlib.crc32(b"IHDR" + ihdr_payload).to_bytes(4, "big")
  138. ihdr = struct.pack(">I", len(ihdr_payload)) + b"IHDR" + ihdr_payload + ihdr_crc
  139. idat_payload = b"\x00" * (32 * 16 * 4 + 16)
  140. idat_compressed = zlib.compress(idat_payload)
  141. idat_crc = zlib.crc32(b"IDAT" + idat_compressed).to_bytes(4, "big")
  142. idat = (
  143. struct.pack(">I", len(idat_compressed)) + b"IDAT" + idat_compressed + idat_crc
  144. )
  145. iend = b"\x00\x00\x00\x00IEND\xaeB`\x82"
  146. big_png = sig + ihdr + idat + iend
  147. normalized_small = normalize_image_inputs([{"base64": _b64(PNG_A)}])
  148. normalized_big = normalize_image_inputs([{"base64": _b64(big_png)}])
  149. assert image_cache_metadata(normalized_small)[0]["width"] == 1
  150. assert image_cache_metadata(normalized_big)[0]["width"] == 32
  151. assert image_cache_metadata(normalized_big)[0]["height"] == 16