utils.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from urllib.parse import urlparse
  2. import phonenumbers
  3. import ipaddress
  4. from phonenumbers import NumberParseException
  5. from pydantic import TypeAdapter, BaseModel
  6. from urllib.parse import urlparse
  7. import re
  8. import ssl
  9. import socket
  10. from typing import Dict, Any, List, Type
  11. import inspect
  12. from typing import Any, Dict, Type
  13. from pydantic import BaseModel, TypeAdapter
  14. def is_valid_ip(address: str) -> bool:
  15. try:
  16. ipaddress.ip_address(address)
  17. return True
  18. except ValueError:
  19. return False
  20. def is_valid_username(username: str) -> bool:
  21. if not re.match(r"^[a-zA-Z0-9_-]{3,30}$", username):
  22. return False
  23. return True
  24. def is_valid_email(email: str) -> bool:
  25. if not re.match(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", email):
  26. return False
  27. return True
  28. def is_valid_domain(url_or_domain: str) -> str:
  29. try:
  30. parsed = urlparse(
  31. url_or_domain if "://" in url_or_domain else "http://" + url_or_domain
  32. )
  33. hostname = parsed.hostname or url_or_domain
  34. if not hostname or "." not in hostname:
  35. return False
  36. if not re.match(r"^[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", hostname):
  37. return False
  38. return True
  39. except Exception as e:
  40. return False
  41. def is_root_domain(domain: str) -> bool:
  42. """
  43. Determine if a domain is a root domain or subdomain.
  44. Args:
  45. domain: The domain string to check
  46. Returns:
  47. True if it's a root domain (e.g., example.com), False if it's a subdomain (e.g., sub.example.com)
  48. """
  49. try:
  50. # Remove protocol if present
  51. if "://" in domain:
  52. parsed = urlparse(domain)
  53. domain = parsed.hostname or domain
  54. # Split by dots
  55. parts = domain.split(".")
  56. # Handle common country code TLDs that have 2 parts (e.g., .co.uk, .com.au, .org.uk)
  57. common_cc_tlds = [
  58. ".co.uk",
  59. ".com.au",
  60. ".org.uk",
  61. ".net.uk",
  62. ".gov.uk",
  63. ".ac.uk",
  64. ".co.nz",
  65. ".com.sg",
  66. ".co.jp",
  67. ".co.kr",
  68. ".com.br",
  69. ".com.mx",
  70. ]
  71. # Check if the domain ends with a common country code TLD
  72. for cc_tld in common_cc_tlds:
  73. if domain.endswith(cc_tld):
  74. # For country code TLDs, we need exactly 3 parts (e.g., example.co.uk)
  75. return len(parts) == 3
  76. # For regular TLDs, a root domain has 2 parts (e.g., example.com)
  77. # A subdomain has 3 or more parts (e.g., sub.example.com, www.sub.example.com)
  78. return len(parts) == 2
  79. except Exception:
  80. # If we can't parse it, assume it's not a root domain
  81. return False
  82. def is_valid_number(phone: str, region: str = "FR") -> None:
  83. """
  84. Validates a phone number. Raises InvalidPhoneNumberError if invalid.
  85. - `region` should be ISO 3166-1 alpha-2 country code (e.g., 'FR' for France)
  86. """
  87. try:
  88. parsed = phonenumbers.parse(phone, region)
  89. if not phonenumbers.is_valid_number(parsed):
  90. return False
  91. except NumberParseException:
  92. return False
  93. def parse_asn(asn: str) -> int:
  94. if not is_valid_asn(asn):
  95. raise ValueError(f"Invalid ASN format: {asn}")
  96. return int(re.sub(r"(?i)^AS", "", asn))
  97. def is_valid_asn(asn: str) -> bool:
  98. if not re.fullmatch(r"(AS)?\d+", asn, re.IGNORECASE):
  99. return False
  100. asn_num = int(re.sub(r"(?i)^AS", "", asn))
  101. return 0 <= asn_num <= 4294967295
  102. def resolve_type(details: dict, schema_context: dict = None) -> str:
  103. if "anyOf" in details:
  104. types = []
  105. for option in details["anyOf"]:
  106. if "$ref" in option:
  107. ref = option["$ref"].split("/")[-1]
  108. types.append(ref)
  109. elif option.get("type") == "array":
  110. # Handle array types within anyOf
  111. item_type = resolve_type(option.get("items", {}), schema_context)
  112. types.append(f"{item_type}[]")
  113. else:
  114. types.append(option.get("type", "unknown"))
  115. return " | ".join(types)
  116. if "type" in details:
  117. if details["type"] == "array":
  118. item_type = resolve_type(details.get("items", {}), schema_context)
  119. return f"{item_type}[]"
  120. return details["type"]
  121. # Handle $ref in array items or other contexts
  122. if "$ref" in details and schema_context:
  123. ref_path = details["$ref"]
  124. if ref_path.startswith("#/$defs/"):
  125. ref_name = ref_path.split("/")[-1]
  126. return ref_name
  127. return "any"
  128. def extract_input_schema_flow(model: Type[BaseModel]) -> Dict[str, Any]:
  129. adapter = TypeAdapter(model)
  130. schema = adapter.json_schema()
  131. # Use the main schema properties, not the $defs
  132. type_name = model.__name__
  133. details = schema
  134. return {
  135. "class_name": model.__name__,
  136. "name": model.__name__,
  137. "module": model.__module__,
  138. "description": inspect.cleandoc(model.__doc__ or ""),
  139. "outputs": {
  140. "type": type_name,
  141. "properties": [
  142. {"name": prop, "type": resolve_type(info, schema)}
  143. for prop, info in details.get("properties", {}).items()
  144. ],
  145. },
  146. "inputs": {"type": "", "properties": []},
  147. "type": "type",
  148. "category": model.__name__,
  149. }
  150. def get_domain_from_ssl(ip: str, port: int = 443) -> str | None:
  151. try:
  152. context = ssl.create_default_context()
  153. with socket.create_connection((ip, port), timeout=3) as sock:
  154. with context.wrap_socket(sock, server_hostname=ip) as ssock:
  155. cert = ssock.getpeercert()
  156. subject = cert.get("subject", [])
  157. for entry in subject:
  158. if entry[0][0] == "commonName":
  159. return entry[0][1]
  160. # Alternative: check subjectAltName
  161. san = cert.get("subjectAltName", [])
  162. for typ, val in san:
  163. if typ == "DNS":
  164. return val
  165. except Exception as e:
  166. print(f"SSL extraction failed for {ip}: {e}")
  167. return None
  168. def extract_enricher(enricher: Dict[str, Any]) -> Dict[str, Any]:
  169. nodes = enricher["nodes"]
  170. edges = enricher["edges"]
  171. input_node = next((node for node in nodes if node["data"]["type"] == "type"), None)
  172. if not input_node:
  173. raise ValueError("No input node found.")
  174. input_output = input_node["data"]["outputs"]
  175. node_lookup = {node["id"]: node for node in nodes}
  176. enrichers = []
  177. for edge in edges:
  178. target_id = edge["target"]
  179. source_handle = edge["sourceHandle"]
  180. target_handle = edge["targetHandle"]
  181. enricher_node = node_lookup.get(target_id)
  182. if enricher_node and enricher_node["data"]["type"] == "enricher":
  183. enrichers.append(
  184. {
  185. "enricher_name": enricher_node["data"]["name"],
  186. "module": enricher_node["data"]["module"],
  187. "input": source_handle,
  188. "output": target_handle,
  189. }
  190. )
  191. return {
  192. "input": {
  193. "name": input_node["data"]["name"],
  194. "outputs": input_output,
  195. },
  196. "enrichers": enrichers,
  197. "enricher_names": [enricher["enricher_name"] for enricher in enrichers],
  198. }
  199. def get_label_color(label: str) -> str:
  200. color_map = {"subdomain": "#A5ABB6", "domain": "#68BDF6", "default": "#A5ABB6"}
  201. return color_map.get(label, color_map["default"])
  202. def flatten(data_dict, prefix=""):
  203. """
  204. Flattens a dictionary to contain only Neo4j-compatible property values.
  205. Neo4j supports primitive types (string, number, boolean) and arrays of those types.
  206. Args:
  207. data_dict (dict): Dictionary to flatten
  208. Returns:
  209. dict: Flattened dictionary with only Neo4j-compatible values
  210. """
  211. flattened = {}
  212. if not isinstance(data_dict, dict):
  213. return flattened
  214. for key, value in data_dict.items():
  215. if value is None:
  216. continue
  217. if isinstance(value, (str, int, float, bool)) or (
  218. isinstance(value, list)
  219. and all(isinstance(item, (str, int, float, bool)) for item in value)
  220. ):
  221. key = f"{prefix}{key}"
  222. flattened[key] = value
  223. return flattened
  224. def get_inline_relationships(nodes: List[Any], edges: List[Any]) -> List[str]:
  225. """
  226. Get the inline relationships for a list of nodes and edges.
  227. """
  228. relationships = []
  229. for edge in edges:
  230. source = next((node for node in nodes if node["id"] == edge["source"]), None)
  231. target = next((node for node in nodes if node["id"] == edge["target"]), None)
  232. if source and target:
  233. relationships.append({"source": source, "edge": edge, "target": target})
  234. return relationships
  235. def to_json_serializable(obj):
  236. """Convert any object to a JSON-serializable format."""
  237. import json
  238. from pydantic import BaseModel
  239. try:
  240. # Test if already JSON serializable
  241. json.dumps(obj)
  242. return obj
  243. except (TypeError, ValueError):
  244. # Handle common cases
  245. if isinstance(obj, BaseModel):
  246. # Use mode='json' to ensure all Pydantic types are properly serialized
  247. return (
  248. obj.model_dump(mode="json")
  249. if hasattr(obj, "model_dump")
  250. else obj.dict()
  251. )
  252. elif isinstance(obj, (list, tuple)):
  253. return [to_json_serializable(item) for item in obj]
  254. elif isinstance(obj, dict):
  255. return {key: to_json_serializable(value) for key, value in obj.items()}
  256. else:
  257. # Convert anything else to string
  258. return str(obj)