_vision_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """Shared image-input normalization for LLM bindings.
  2. All LLM bindings accept a unified ``image_inputs`` keyword parameter. Each
  3. element may be:
  4. - a raw base64 string (the MIME type is inferred from the payload's magic
  5. bytes, defaulting to ``image/png``);
  6. - a data URL of the form ``data:<mime>;base64,<payload>``;
  7. - a dict with keys ``base64`` (required) and optional ``mime_type``,
  8. ``source_id``, ``source_file``, ``modality``, ``doc_id``.
  9. The provider-specific binding code converts the normalized result to its own
  10. content-block format. The VLM pipeline uses :func:`image_cache_metadata` for
  11. cache-key inputs (deliberately excluding ``source_id`` / ``source_file`` so the
  12. same image at different filenames still hits the same entry) and
  13. :func:`image_audit_metadata` for the human-readable ``original_prompt`` audit
  14. block.
  15. """
  16. from __future__ import annotations
  17. import base64
  18. import hashlib
  19. import re
  20. import struct
  21. from dataclasses import dataclass
  22. from pathlib import Path
  23. from typing import Any
  24. DATA_URL_RE = re.compile(
  25. r"^data:(?P<mime>[\w./+-]+);base64,(?P<data>[A-Za-z0-9+/=\s]+)$"
  26. )
  27. _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
  28. _JPEG_SIGNATURE = b"\xff\xd8\xff"
  29. _GIF_SIGNATURES = (b"GIF87a", b"GIF89a")
  30. _WEBP_RIFF = b"RIFF"
  31. _WEBP_TAG = b"WEBP"
  32. @dataclass(frozen=True)
  33. class NormalizedImage:
  34. index: int
  35. raw_bytes: bytes
  36. mime_type: str
  37. sha256: str
  38. base64_str: str
  39. source_id: str | None
  40. source_file: str | None
  41. modality: str | None
  42. doc_id: str | None
  43. # Pixel dimensions parsed from the raster header (None when the format
  44. # is recognized but dimensions could not be extracted).
  45. width: int | None = None
  46. height: int | None = None
  47. def _detect_mime(raw: bytes) -> str:
  48. if raw.startswith(_PNG_SIGNATURE):
  49. return "image/png"
  50. if raw.startswith(_JPEG_SIGNATURE):
  51. return "image/jpeg"
  52. if any(raw.startswith(sig) for sig in _GIF_SIGNATURES):
  53. return "image/gif"
  54. if len(raw) >= 12 and raw[0:4] == _WEBP_RIFF and raw[8:12] == _WEBP_TAG:
  55. return "image/webp"
  56. return "image/png"
  57. def _decode_base64(data: str) -> bytes:
  58. cleaned = re.sub(r"\s+", "", data)
  59. try:
  60. return base64.b64decode(cleaned, validate=True)
  61. except (base64.binascii.Error, ValueError) as exc:
  62. raise ValueError(f"invalid base64 image data: {exc}") from exc
  63. def _coerce_item(item: Any) -> dict[str, Any]:
  64. if isinstance(item, str):
  65. match = DATA_URL_RE.match(item.strip())
  66. if match:
  67. return {"base64": match.group("data"), "mime_type": match.group("mime")}
  68. return {"base64": item}
  69. if isinstance(item, dict):
  70. if "base64" not in item:
  71. raise ValueError("image_inputs dict element must contain a 'base64' key")
  72. return item
  73. raise TypeError(
  74. f"image_inputs element must be str or dict, got {type(item).__name__}"
  75. )
  76. def normalize_image_inputs(
  77. image_inputs: list[Any] | None,
  78. ) -> list[NormalizedImage]:
  79. """Normalize the unified ``image_inputs`` parameter.
  80. Returns an empty list when ``image_inputs`` is falsy, so callers can do a
  81. plain ``if normalized:`` check.
  82. """
  83. if not image_inputs:
  84. return []
  85. result: list[NormalizedImage] = []
  86. for idx, raw_item in enumerate(image_inputs):
  87. item = _coerce_item(raw_item)
  88. raw_bytes = _decode_base64(item["base64"])
  89. if not raw_bytes:
  90. raise ValueError(f"image_inputs[{idx}] decoded to empty bytes")
  91. mime_type = item.get("mime_type") or _detect_mime(raw_bytes)
  92. sha = hashlib.sha256(raw_bytes).hexdigest()
  93. clean_b64 = base64.b64encode(raw_bytes).decode("ascii")
  94. dims = _dimensions_from_bytes(raw_bytes)
  95. width, height = (dims[0], dims[1]) if dims else (None, None)
  96. result.append(
  97. NormalizedImage(
  98. index=idx,
  99. raw_bytes=raw_bytes,
  100. mime_type=mime_type,
  101. sha256=sha,
  102. base64_str=clean_b64,
  103. source_id=item.get("source_id"),
  104. source_file=item.get("source_file"),
  105. modality=item.get("modality"),
  106. doc_id=item.get("doc_id"),
  107. width=width,
  108. height=height,
  109. )
  110. )
  111. return result
  112. def image_cache_metadata(images: list[NormalizedImage]) -> list[dict[str, Any]]:
  113. """Return cache-key-safe image metadata (no source identifiers).
  114. Includes ``width`` / ``height`` so the cache key reflects the full
  115. image digest the design contract specifies (mime, sha256, bytes,
  116. width, height). The sha256 alone is sufficient for identity, but
  117. surfacing dimensions matches the documented audit shape and gives
  118. diagnostics a one-line "what was sent" without re-decoding.
  119. """
  120. return [
  121. {
  122. "index": img.index,
  123. "mime_type": img.mime_type,
  124. "sha256": img.sha256,
  125. "bytes": len(img.raw_bytes),
  126. "width": img.width,
  127. "height": img.height,
  128. }
  129. for img in images
  130. ]
  131. def image_audit_metadata(images: list[NormalizedImage]) -> list[dict[str, Any]]:
  132. """Return audit metadata suitable for the ``original_prompt`` block.
  133. Never includes the raw base64 payload — only digests and source pointers.
  134. """
  135. return [
  136. {
  137. "index": img.index,
  138. "mime_type": img.mime_type,
  139. "sha256": img.sha256,
  140. "bytes": len(img.raw_bytes),
  141. "width": img.width,
  142. "height": img.height,
  143. "source_id": img.source_id,
  144. "source_file": img.source_file,
  145. "modality": img.modality,
  146. "doc_id": img.doc_id,
  147. }
  148. for img in images
  149. ]
  150. def _read_png_dimensions(data: bytes) -> tuple[int, int] | None:
  151. # IHDR is the first chunk; width/height are big-endian uint32 at offsets
  152. # 16/20 (8-byte signature + 4 length + 4 "IHDR" + 4 width + 4 height).
  153. if len(data) < 24 or not data.startswith(_PNG_SIGNATURE):
  154. return None
  155. width, height = struct.unpack(">II", data[16:24])
  156. return width, height
  157. def _read_gif_dimensions(data: bytes) -> tuple[int, int] | None:
  158. # Logical screen descriptor: width/height are little-endian uint16 at
  159. # offsets 6/8.
  160. if len(data) < 10 or not any(data.startswith(sig) for sig in _GIF_SIGNATURES):
  161. return None
  162. width, height = struct.unpack("<HH", data[6:10])
  163. return width, height
  164. def _read_jpeg_dimensions(data: bytes) -> tuple[int, int] | None:
  165. # Scan for a Start-Of-Frame marker (SOF0 / SOF2 / etc.). Skip segments by
  166. # their length field. We deliberately accept any SOF variant the codec
  167. # might emit rather than enumerating each one.
  168. if len(data) < 4 or not data.startswith(_JPEG_SIGNATURE):
  169. return None
  170. i = 2
  171. n = len(data)
  172. while i < n:
  173. if data[i] != 0xFF:
  174. return None
  175. # Skip fill bytes.
  176. while i < n and data[i] == 0xFF:
  177. i += 1
  178. if i >= n:
  179. return None
  180. marker = data[i]
  181. i += 1
  182. # Standalone markers without a length field.
  183. if marker in (0xD8, 0xD9) or 0xD0 <= marker <= 0xD7:
  184. continue
  185. if i + 2 > n:
  186. return None
  187. segment_len = struct.unpack(">H", data[i : i + 2])[0]
  188. if segment_len < 2 or i + segment_len > n:
  189. return None
  190. # SOF0..SOF15 except 0xC4 (DHT), 0xC8 (JPG reserved), 0xCC (DAC).
  191. if 0xC0 <= marker <= 0xCF and marker not in (0xC4, 0xC8, 0xCC):
  192. # SOF payload: precision(1) + height(2) + width(2) + …
  193. if i + 7 > n:
  194. return None
  195. height, width = struct.unpack(">HH", data[i + 3 : i + 7])
  196. return width, height
  197. i += segment_len
  198. return None
  199. def _read_webp_dimensions(data: bytes) -> tuple[int, int] | None:
  200. if len(data) < 30 or data[0:4] != _WEBP_RIFF or data[8:12] != _WEBP_TAG:
  201. return None
  202. chunk_type = data[12:16]
  203. if chunk_type == b"VP8 ":
  204. # Lossy: 3-byte tag + 3-byte sync code at offset 23, then 4 bytes
  205. # holding 14-bit width / 14-bit height in little-endian halves.
  206. if len(data) < 30:
  207. return None
  208. width = struct.unpack("<H", data[26:28])[0] & 0x3FFF
  209. height = struct.unpack("<H", data[28:30])[0] & 0x3FFF
  210. return width, height
  211. if chunk_type == b"VP8L":
  212. # Lossless: signature(0x2F) + 4 bytes encoding 14-bit width-1 / 14-bit
  213. # height-1 starting at offset 21.
  214. if len(data) < 25 or data[20] != 0x2F:
  215. return None
  216. b0, b1, b2, b3 = data[21], data[22], data[23], data[24]
  217. width = ((b1 & 0x3F) << 8 | b0) + 1
  218. height = ((b3 & 0x0F) << 10 | b2 << 2 | (b1 & 0xC0) >> 6) + 1
  219. return width, height
  220. if chunk_type == b"VP8X":
  221. # Extended: 3 bytes width-1 / 3 bytes height-1, little-endian, at
  222. # offsets 24/27.
  223. if len(data) < 30:
  224. return None
  225. width = (data[24] | data[25] << 8 | data[26] << 16) + 1
  226. height = (data[27] | data[28] << 8 | data[29] << 16) + 1
  227. return width, height
  228. return None
  229. def read_image_dimensions(path: Path) -> tuple[int, int] | None:
  230. """Return ``(width, height)`` for a raster image, or ``None`` if unknown.
  231. Reads only the file header — no Pillow dependency. Supports PNG, JPEG,
  232. GIF and WebP (VP8 / VP8L / VP8X). Returns ``None`` for unsupported
  233. formats and on any I/O or parse error so callers can fall back to a
  234. skipped/failure decision without raising.
  235. """
  236. try:
  237. with open(path, "rb") as fh:
  238. header = fh.read(64 * 1024)
  239. except OSError:
  240. return None
  241. return _dimensions_from_bytes(header)
  242. def _dimensions_from_bytes(data: bytes) -> tuple[int, int] | None:
  243. """Run the four header readers against a byte buffer.
  244. Shared between the file-path entry point (:func:`read_image_dimensions`)
  245. and :func:`normalize_image_inputs`, which receives raster payloads
  246. decoded from the unified ``image_inputs`` parameter.
  247. """
  248. if not data:
  249. return None
  250. for reader in (
  251. _read_png_dimensions,
  252. _read_gif_dimensions,
  253. _read_jpeg_dimensions,
  254. _read_webp_dimensions,
  255. ):
  256. try:
  257. dims = reader(data)
  258. except (struct.error, IndexError, ValueError):
  259. continue
  260. if dims:
  261. return dims
  262. return None