test_vault.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """Tests for the Vault class and secret management."""
  2. import pytest
  3. import uuid
  4. import os
  5. from unittest.mock import Mock, MagicMock
  6. from flowsint_core.core.vault import Vault
  7. from flowsint_core.core.models import Key
  8. @pytest.fixture
  9. def mock_db():
  10. """Create a mock database session."""
  11. return Mock()
  12. @pytest.fixture
  13. def owner_id():
  14. """Create a test owner ID."""
  15. return uuid.uuid4()
  16. @pytest.fixture
  17. def vault(mock_db, owner_id):
  18. """Create a Vault instance with mocked database."""
  19. return Vault(db=mock_db, owner_id=owner_id)
  20. @pytest.fixture(autouse=True)
  21. def mock_master_key(monkeypatch):
  22. """Mock the master key environment variable."""
  23. # Use a base64-encoded 32-byte key
  24. test_key = "base64:qnHTmwYb+uoygIw9MsRMY22vS5YPchY+QOi/E79GAvM="
  25. monkeypatch.setenv("MASTER_VAULT_KEY_V1", test_key)
  26. class TestVaultInitialization:
  27. """Tests for Vault initialization."""
  28. def test_vault_requires_owner_id(self, mock_db):
  29. """Test that Vault requires an owner_id."""
  30. with pytest.raises(ValueError, match="owner_id is required"):
  31. Vault(db=mock_db, owner_id=None)
  32. def test_vault_initialization_success(self, mock_db, owner_id):
  33. """Test successful Vault initialization."""
  34. vault = Vault(db=mock_db, owner_id=owner_id)
  35. assert vault.db == mock_db
  36. assert vault.owner_id == str(owner_id)
  37. assert vault.version == "V1"
  38. class TestVaultSetSecret:
  39. """Tests for Vault.set_secret() method."""
  40. def test_set_secret_creates_key(self, vault, mock_db):
  41. """Test that set_secret creates a new Key in the database."""
  42. vault_ref = "TEST_API_KEY"
  43. plain_key = "my-secret-api-key-12345"
  44. vault.set_secret(vault_ref, plain_key)
  45. # Verify that db.add, db.commit, and db.refresh were called
  46. assert mock_db.add.called
  47. assert mock_db.commit.called
  48. assert mock_db.refresh.called
  49. # Get the Key object that was added
  50. added_key = mock_db.add.call_args[0][0]
  51. assert isinstance(added_key, Key)
  52. assert added_key.name == vault_ref
  53. assert added_key.owner_id == vault.owner_id
  54. assert added_key.key_version == "V1"
  55. assert added_key.iv is not None
  56. assert added_key.salt is not None
  57. assert added_key.ciphertext is not None
  58. def test_set_secret_encrypts_data(self, vault, mock_db):
  59. """Test that set_secret properly encrypts the secret."""
  60. vault_ref = "TEST_API_KEY"
  61. plain_key = "my-secret-api-key-12345"
  62. vault.set_secret(vault_ref, plain_key)
  63. added_key = mock_db.add.call_args[0][0]
  64. # Ciphertext should not contain the plaintext
  65. assert plain_key.encode() not in added_key.ciphertext
  66. # IV and salt should be different lengths (12 and 16 bytes)
  67. assert len(added_key.iv) == 12
  68. assert len(added_key.salt) == 16
  69. class TestVaultGetSecret:
  70. """Tests for Vault.get_secret() method."""
  71. def test_get_secret_by_name_found(self, vault, mock_db, owner_id):
  72. """Test getting a secret by name when it exists."""
  73. vault_ref = "TEST_API_KEY"
  74. plain_key = "my-secret-api-key-12345"
  75. # Set a secret first to get encrypted data
  76. real_vault = Vault(db=MagicMock(), owner_id=owner_id)
  77. encrypted_data = real_vault._encrypt_key(plain_key)
  78. # Create a mock Key object
  79. mock_key = Mock()
  80. mock_key.name = vault_ref
  81. mock_key.id = uuid.uuid4()
  82. mock_key.owner_id = str(owner_id)
  83. mock_key.salt = encrypted_data["salt"]
  84. mock_key.iv = encrypted_data["iv"]
  85. mock_key.ciphertext = encrypted_data["ciphertext"]
  86. # Mock the database query
  87. mock_result = Mock()
  88. mock_result.scalars().first.return_value = mock_key
  89. mock_db.execute.return_value = mock_result
  90. # Get the secret
  91. result = vault.get_secret(vault_ref)
  92. assert result == plain_key
  93. assert mock_db.execute.called
  94. def test_get_secret_by_uuid_found(self, vault, mock_db, owner_id):
  95. """Test getting a secret by UUID when it exists."""
  96. key_id = uuid.uuid4()
  97. plain_key = "my-secret-api-key-12345"
  98. # Set a secret first to get encrypted data
  99. real_vault = Vault(db=MagicMock(), owner_id=owner_id)
  100. encrypted_data = real_vault._encrypt_key(plain_key)
  101. # Create a mock Key object
  102. mock_key = Mock()
  103. mock_key.name = "TEST_API_KEY"
  104. mock_key.id = key_id
  105. mock_key.owner_id = str(owner_id)
  106. mock_key.salt = encrypted_data["salt"]
  107. mock_key.iv = encrypted_data["iv"]
  108. mock_key.ciphertext = encrypted_data["ciphertext"]
  109. # Mock the database query
  110. mock_result = Mock()
  111. mock_result.scalars().first.return_value = mock_key
  112. mock_db.execute.return_value = mock_result
  113. # Get the secret by UUID
  114. result = vault.get_secret(str(key_id))
  115. assert result == plain_key
  116. assert mock_db.execute.called
  117. def test_get_secret_not_found(self, vault, mock_db):
  118. """Test getting a secret that doesn't exist."""
  119. vault_ref = "NONEXISTENT_KEY"
  120. # Mock the database query to return None
  121. mock_result = Mock()
  122. mock_result.scalars().first.return_value = None
  123. mock_db.execute.return_value = mock_result
  124. result = vault.get_secret(vault_ref)
  125. assert result is None
  126. def test_get_secret_wrong_owner(self, vault, mock_db):
  127. """Test that secrets from other owners cannot be accessed."""
  128. vault_ref = "TEST_API_KEY"
  129. # Mock the database query to return None (no key found for this owner)
  130. mock_result = Mock()
  131. mock_result.scalars().first.return_value = None
  132. mock_db.execute.return_value = mock_result
  133. result = vault.get_secret(vault_ref)
  134. assert result is None
  135. class TestVaultEncryptionDecryption:
  136. """Tests for encryption and decryption methods."""
  137. def test_encrypt_decrypt_roundtrip(self, vault):
  138. """Test that encryption and decryption work correctly."""
  139. plaintext = "my-secret-api-key-12345"
  140. # Encrypt
  141. encrypted_data = vault._encrypt_key(plaintext)
  142. assert "ciphertext" in encrypted_data
  143. assert "iv" in encrypted_data
  144. assert "salt" in encrypted_data
  145. assert plaintext.encode() not in encrypted_data["ciphertext"]
  146. # Decrypt
  147. decrypted = vault._decrypt_key(encrypted_data)
  148. assert decrypted == plaintext
  149. def test_different_salts_produce_different_ciphertexts(self, vault):
  150. """Test that the same plaintext with different salts produces different ciphertexts."""
  151. plaintext = "my-secret-api-key-12345"
  152. encrypted1 = vault._encrypt_key(plaintext)
  153. encrypted2 = vault._encrypt_key(plaintext)
  154. # Different salts and IVs
  155. assert encrypted1["salt"] != encrypted2["salt"]
  156. assert encrypted1["iv"] != encrypted2["iv"]
  157. # Different ciphertexts
  158. assert encrypted1["ciphertext"] != encrypted2["ciphertext"]
  159. def test_master_key_derivation(self, vault):
  160. """Test that master key is properly derived."""
  161. master_key = vault._get_master_key()
  162. assert isinstance(master_key, bytes)
  163. assert len(master_key) == 32 # 256 bits
  164. def test_invalid_master_key_length(self, vault, monkeypatch):
  165. """Test that invalid master key length raises error."""
  166. import base64
  167. # Set a valid base64 but wrong length (16 bytes instead of 32)
  168. short_key = base64.b64encode(b"0" * 16).decode()
  169. monkeypatch.setenv("MASTER_VAULT_KEY_V1", f"base64:{short_key}")
  170. with pytest.raises(ValueError, match="Master key must be 32 bytes \\(256 bits\\)"):
  171. vault._get_master_key()
  172. def test_missing_master_key(self, vault, monkeypatch):
  173. """Test that missing master key raises error."""
  174. monkeypatch.delenv("MASTER_VAULT_KEY_V1", raising=False)
  175. with pytest.raises(ValueError, match="Missing master key"):
  176. vault._get_master_key()