From 2826b74a65c8b15208d929da4005db5ba08e5bee Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Thu, 20 Feb 2025 10:55:21 +0100 Subject: [PATCH 1/8] feat: remove not needed encryption of secrets Instead use an uuid generator as we do for pii, and reuse same session store mechanism Closes: #929 --- src/codegate/pipeline/base.py | 8 -- src/codegate/pipeline/pii/analyzer.py | 67 ++++-------- src/codegate/pipeline/pii/manager.py | 23 ++-- src/codegate/pipeline/pii/pii.py | 21 ++-- src/codegate/pipeline/secrets/gatecrypto.py | 111 -------------------- src/codegate/pipeline/secrets/manager.py | 108 ++++--------------- src/codegate/pipeline/secrets/secrets.py | 9 +- src/codegate/session/session_store.py | 30 ++++++ tests/pipeline/pii/test_analyzer.py | 63 +++-------- tests/pipeline/secrets/test_manager.py | 9 +- 10 files changed, 120 insertions(+), 329 deletions(-) delete mode 100644 src/codegate/pipeline/secrets/gatecrypto.py create mode 100644 src/codegate/session/session_store.py diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 0baa322a..f9ce39b6 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -30,16 +30,8 @@ def secure_cleanup(self): """Securely cleanup sensitive data for this session""" if self.manager is None or self.session_id == "": return - self.manager.cleanup_session(self.session_id) self.session_id = "" - - # Securely wipe the API key using the same method as secrets manager - if self.api_key is not None: - api_key_bytes = bytearray(self.api_key.encode()) - self.manager.crypto.wipe_bytearray(api_key_bytes) - self.api_key = None - self.model = None diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py index a1ed5bed..75d5f299 100644 --- a/src/codegate/pipeline/pii/analyzer.py +++ b/src/codegate/pipeline/pii/analyzer.py @@ -7,41 +7,11 @@ from codegate.db.models import AlertSeverity from codegate.pipeline.base import PipelineContext +from codegate.session.session_store import SessionStore logger = structlog.get_logger("codegate.pii.analyzer") -class PiiSessionStore: - """ - A class to manage PII (Personally Identifiable Information) session storage. - - Attributes: - session_id (str): The unique identifier for the session. If not provided, a new UUID - is generated. mappings (Dict[str, str]): A dictionary to store mappings between UUID - placeholders and PII. - - Methods: - add_mapping(pii: str) -> str: - Adds a PII string to the session store and returns a UUID placeholder for it. - - get_pii(uuid_placeholder: str) -> str: - Retrieves the PII string associated with the given UUID placeholder. If the placeholder - is not found, returns the placeholder itself. - """ - - def __init__(self, session_id: str = None): - self.session_id = session_id or str(uuid.uuid4()) - self.mappings: Dict[str, str] = {} - - def add_mapping(self, pii: str) -> str: - uuid_placeholder = f"<{str(uuid.uuid4())}>" - self.mappings[uuid_placeholder] = pii - return uuid_placeholder - - def get_pii(self, uuid_placeholder: str) -> str: - return self.mappings.get(uuid_placeholder, uuid_placeholder) - - class PiiAnalyzer: """ PiiAnalyzer class for analyzing and anonymizing text containing PII. @@ -52,12 +22,12 @@ class PiiAnalyzer: Get or create the singleton instance of PiiAnalyzer. analyze: text (str): The text to analyze for PII. - Tuple[str, List[Dict[str, Any]], PiiSessionStore]: The anonymized text, a list of + Tuple[str, List[Dict[str, Any]], SessionStore]: The anonymized text, a list of found PII details, and the session store. entities (List[str]): The PII entities to analyze for. restore_pii: anonymized_text (str): The text with anonymized PII. - session_store (PiiSessionStore): The PiiSessionStore used for anonymization. + session_store (SessionStore): The SessionStore used for anonymization. str: The text with original PII restored. """ @@ -95,13 +65,13 @@ def __init__(self): # Create analyzer with custom NLP engine self.analyzer = AnalyzerEngine(nlp_engine=nlp_engine) self.anonymizer = AnonymizerEngine() - self.session_store = PiiSessionStore() + self.session_store = SessionStore() PiiAnalyzer._instance = self def analyze( - self, text: str, context: Optional[PipelineContext] = None - ) -> Tuple[str, List[Dict[str, Any]], PiiSessionStore]: + self, text: str, session_id: str, context: Optional[PipelineContext] = None + ) -> Tuple[str, List[Dict[str, Any]]]: # Prioritize credit card detection first entities = [ "PHONE_NUMBER", @@ -135,7 +105,7 @@ def analyze( anonymized_text = text for result in analyzer_results: pii_value = text[result.start : result.end] - uuid_placeholder = self.session_store.add_mapping(pii_value) + uuid_placeholder = self.session_store.add_mapping(session_id, pii_value) pii_info = { "type": result.entity_type, "value": pii_value, @@ -155,7 +125,7 @@ def analyze( uuid=uuid_placeholder, # Don't log the actual PII value for security value_length=len(pii_value), - session_id=self.session_store.session_id, + session_id=session_id, ) # Log summary of all PII found in this analysis @@ -176,30 +146,37 @@ def analyze( "PII analysis complete", total_pii_found=len(found_pii), pii_types=[p["type"] for p in found_pii], - session_id=self.session_store.session_id, + session_id=session_id, ) # Return the anonymized text, PII details, and session store - return anonymized_text, found_pii, self.session_store + return anonymized_text, found_pii # If no PII found, return original text, empty list, and session store - return text, [], self.session_store + return text, [] - def restore_pii(self, anonymized_text: str, session_store: PiiSessionStore) -> str: + def restore_pii(self, anonymized_text: str, session_id: str) -> str: """ Restore the original PII (Personally Identifiable Information) in the given anonymized text. This method replaces placeholders in the anonymized text with their corresponding original - PII values using the mappings stored in the provided PiiSessionStore. + PII values using the mappings stored in the provided SessionStore. Args: anonymized_text (str): The text containing placeholders for PII. - session_store (PiiSessionStore): The session store containing mappings of placeholders + session_store (SessionStore): The session store containing mappings of placeholders to original PII. Returns: str: The text with the original PII restored. """ - for uuid_placeholder, original_pii in session_store.mappings.items(): + session_data = self.session_store.get_by_session_id(session_id) + if not session_data: + logger.warning( + "No active PII session found for given session ID. Unable to restore PII." + ) + return anonymized_text + + for uuid_placeholder, original_pii in session_data.items(): anonymized_text = anonymized_text.replace(uuid_placeholder, original_pii) return anonymized_text diff --git a/src/codegate/pipeline/pii/manager.py b/src/codegate/pipeline/pii/manager.py index 54112713..265bb8ce 100644 --- a/src/codegate/pipeline/pii/manager.py +++ b/src/codegate/pipeline/pii/manager.py @@ -3,7 +3,8 @@ import structlog from codegate.pipeline.base import PipelineContext -from codegate.pipeline.pii.analyzer import PiiAnalyzer, PiiSessionStore +from codegate.pipeline.pii.analyzer import PiiAnalyzer +from codegate.session.session_store import SessionStore logger = structlog.get_logger("codegate") @@ -16,14 +17,14 @@ class PiiManager: Attributes: analyzer (PiiAnalyzer): The singleton instance of PiiAnalyzer used for PII detection and restoration. - session_store (PiiSessionStore): The session store for the current PII session. + session_store (SessionStore): The session store for the current PII session. Methods: __init__(): Initializes the PiiManager with the singleton PiiAnalyzer instance and sets the session store. - analyze(text: str) -> Tuple[str, List[Dict[str, Any]]]: + analyze(text: str, session_id: str) -> Tuple[str, List[Dict[str, Any]]]: Analyzes the given text for PII, anonymizes it, and logs the detected PII details. Args: text (str): The text to be analyzed for PII. @@ -31,7 +32,7 @@ class PiiManager: Tuple[str, List[Dict[str, Any]]]: A tuple containing the anonymized text and a list of found PII details. - restore_pii(anonymized_text: str) -> str: + restore_pii(anonymized_text: str, session_id: str ) -> str: Restores the PII in the given anonymized text using the current session. Args: anonymized_text (str): The text with anonymized PII to be restored. @@ -48,16 +49,16 @@ def __init__(self): self._session_store = self.analyzer.session_store @property - def session_store(self) -> PiiSessionStore: + def session_store(self) -> SessionStore: """Get the current session store.""" # Always return the analyzer's current session store return self.analyzer.session_store def analyze( - self, text: str, context: Optional[PipelineContext] = None + self, text: str, session_id: str, context: Optional[PipelineContext] = None ) -> Tuple[str, List[Dict[str, Any]]]: # Call analyzer and get results - anonymized_text, found_pii, _ = self.analyzer.analyze(text, context=context) + anonymized_text, found_pii = self.analyzer.analyze(text, session_id, context=context) # Log found PII details (without modifying the found_pii list) if found_pii: @@ -72,13 +73,9 @@ def analyze( # Return the exact same objects we got from the analyzer return anonymized_text, found_pii - def restore_pii(self, anonymized_text: str) -> str: + def restore_pii(self, anonymized_text: str, session_id: str) -> str: """ Restore PII in the given anonymized text using the current session. """ - if self.session_store is None: - logger.warning("No active PII session found. Unable to restore PII.") - return anonymized_text - # Use the analyzer's restore_pii method with the current session store - return self.analyzer.restore_pii(anonymized_text, self.session_store) + return self.analyzer.restore_pii(anonymized_text, session_id) diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index f0b9f271..b19ff5f3 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List, Optional +import uuid import regex as re import structlog @@ -37,7 +38,7 @@ class CodegatePii(PipelineStep): Processes the chat completion request to detect and redact PII. Updates the request with anonymized text and stores PII details in the context metadata. - restore_pii(anonymized_text: str) -> str: + restore_pii(anonymized_text: str, session_id: str) -> str: Restores the original PII from the anonymized text using the PiiManager. """ @@ -75,12 +76,15 @@ async def process( total_pii_found = 0 all_pii_details: List[Dict[str, Any]] = [] last_redacted_text = "" + session_id = context.session_id if hasattr(context, "session_id") else str(uuid.uuid4()) for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: # This is where analyze and anonymize the text original_text = str(message["content"]) - anonymized_text, pii_details = self.pii_manager.analyze(original_text, context) + anonymized_text, pii_details = self.pii_manager.analyze( + original_text, session_id, context + ) if pii_details: total_pii_found += len(pii_details) @@ -99,6 +103,7 @@ async def process( context.metadata["redacted_pii_count"] = total_pii_found context.metadata["redacted_pii_details"] = all_pii_details context.metadata["redacted_text"] = last_redacted_text + context.metadata["session_id"] = session_id if total_pii_found > 0: context.metadata["pii_manager"] = self.pii_manager @@ -113,8 +118,8 @@ async def process( return PipelineResult(request=new_request, context=context) - def restore_pii(self, anonymized_text: str) -> str: - return self.pii_manager.restore_pii(anonymized_text) + def restore_pii(self, anonymized_text: str, session_id: str) -> str: + return self.pii_manager.restore_pii(anonymized_text, session_id) class PiiUnRedactionStep(OutputPipelineStep): @@ -151,7 +156,7 @@ def _is_complete_uuid(self, uuid_str: str) -> bool: """Check if the string is a complete UUID""" return bool(self.complete_uuid_pattern.match(uuid_str)) - async def process_chunk( + async def process_chunk( # noqa: C901 self, chunk: ModelResponse, context: OutputPipelineContext, @@ -162,6 +167,10 @@ async def process_chunk( return [chunk] content = chunk.choices[0].delta.content + session_id = input_context.metadata.get("session_id", "") + if not session_id: + logger.error("Could not get any session id, cannot process pii") + return [chunk] # Add current chunk to buffer if context.prefix_buffer: @@ -199,7 +208,7 @@ async def process_chunk( if pii_manager and pii_manager.session_store: # Restore original value from PII manager logger.debug("Attempting to restore PII from UUID marker") - original = pii_manager.session_store.get_pii(uuid_marker) + original = pii_manager.session_store.get_mapping(session_id, uuid_marker) logger.debug(f"Restored PII: {original}") result.append(original) else: diff --git a/src/codegate/pipeline/secrets/gatecrypto.py b/src/codegate/pipeline/secrets/gatecrypto.py deleted file mode 100644 index 859b025d..00000000 --- a/src/codegate/pipeline/secrets/gatecrypto.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import time -from base64 import b64decode, b64encode - -import structlog -from cryptography.hazmat.primitives.ciphers.aead import AESGCM - -logger = structlog.get_logger("codegate") - - -class CodeGateCrypto: - """ - Manage session keys and provide encryption / decryption of tokens with replay protection. - Attributes: - session_keys (dict): A dictionary to store session keys with their associated timestamps. - SESSION_KEY_LIFETIME (int): The lifetime of a session key in seconds. - NONCE_SIZE (int): The size of the nonce used in AES GCM mode. - Methods: - generate_session_key(session_id): - Generates a session key with an associated timestamp. - get_session_key(session_id): - Retrieves a session key if it is still valid. - cleanup_expired_keys(): - Removes expired session keys from memory. - encrypt_token(token, session_id): - Encrypts a token with a session key and adds a timestamp for replay protection. - decrypt_token(encrypted_token, session_id): - Decrypts a token and validates its timestamp to prevent replay attacks. - wipe_bytearray(data): - Securely wipes a bytearray in-place. - """ - - def __init__(self): - self.session_keys = {} - self.SESSION_KEY_LIFETIME = 600 # 10 minutes - self.NONCE_SIZE = 12 # AES GCM recommended nonce size - - def generate_session_key(self, session_id): - """Generates a session key with an associated timestamp.""" - key = os.urandom(32) # Generate a 256-bit key - self.session_keys[session_id] = (key, time.time()) - return key - - def get_session_key(self, session_id): - """Retrieves a session key if it is still valid.""" - key_data = self.session_keys.get(session_id) - if key_data: - key, timestamp = key_data - if time.time() - timestamp < self.SESSION_KEY_LIFETIME: - return key - else: - # Key has expired - del self.session_keys[session_id] - return None - - def cleanup_expired_keys(self): - """Removes expired session keys from memory.""" - now = time.time() - expired_keys = [ - session_id - for session_id, (key, timestamp) in self.session_keys.items() - if now - timestamp >= self.SESSION_KEY_LIFETIME - ] - for session_id in expired_keys: - del self.session_keys[session_id] - - def encrypt_token(self, token, session_id): - """Encrypts a token with a session key and adds a timestamp for replay protection.""" - key = self.generate_session_key(session_id) - nonce = os.urandom(self.NONCE_SIZE) - timestamp = int(time.time()) - data = f"{token}:{timestamp}".encode() # Append timestamp to token - - aesgcm = AESGCM(key) - ciphertext = aesgcm.encrypt(nonce, data, None) # None for no associated data - - # Combine nonce and ciphertext (which includes the authentication tag) - encrypted_token = b64encode(nonce + ciphertext).decode() - return encrypted_token - - def decrypt_token(self, encrypted_token, session_id): - """Decrypts a token and validates its timestamp to prevent replay attacks.""" - key = self.get_session_key(session_id) - if not key: - raise ValueError("Session key expired or invalid.") - - encrypted_data = b64decode(encrypted_token) - nonce = encrypted_data[: self.NONCE_SIZE] - ciphertext = encrypted_data[self.NONCE_SIZE :] # Includes authentication tag - - aesgcm = AESGCM(key) - try: - decrypted_data = aesgcm.decrypt( - nonce, ciphertext, None - ).decode() # None for no associated data - except Exception as e: - raise ValueError("Decryption failed: Invalid token or tampering detected.") from e - - token, timestamp = decrypted_data.rsplit(":", 1) - if time.time() - int(timestamp) > self.SESSION_KEY_LIFETIME: - raise ValueError("Token has expired.") - - return token - - def wipe_bytearray(self, data): - """Securely wipes a bytearray in-place.""" - if not isinstance(data, bytearray): - raise ValueError("Only bytearray objects can be securely wiped.") - for i in range(len(data)): - data[i] = 0 # Overwrite each byte with 0 - logger.info("Sensitive data securely wiped from memory.") diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py index bef07c75..4b852951 100644 --- a/src/codegate/pipeline/secrets/manager.py +++ b/src/codegate/pipeline/secrets/manager.py @@ -1,30 +1,20 @@ -from typing import NamedTuple, Optional +import json +from typing import Optional import structlog -from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto +from codegate.session.session_store import SessionStore logger = structlog.get_logger("codegate") -class SecretEntry(NamedTuple): - """Represents a stored secret""" - - original: str - encrypted: str - service: str - secret_type: str - - class SecretsManager: """Manages encryption, storage and retrieval of secrets""" def __init__(self): - self.crypto = CodeGateCrypto() - self._session_store: dict[str, dict[str, SecretEntry]] = {} - self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index + self.session_store = SessionStore() - def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str: + def store_secret(self, session_id: str, value: str, service: str, secret_type: str) -> str: """ Encrypts and stores a secret value. Returns the encrypted value. @@ -35,83 +25,23 @@ def store_secret(self, value: str, service: str, secret_type: str, session_id: s raise ValueError("Service must be provided") if not secret_type: raise ValueError("Secret type must be provided") - if not session_id: - raise ValueError("Session ID must be provided") - encrypted_value = self.crypto.encrypt_token(value, session_id) - - # Store mappings - session_secrets = self._session_store.get(session_id, {}) - session_secrets[encrypted_value] = SecretEntry( - original=value, - encrypted=encrypted_value, - service=service, - secret_type=secret_type, + uuid_placeholder = self.session_store.add_mapping( + session_id, + json.dumps({"original": value, "service": service, "secret_type": secret_type}), ) - self._session_store[session_id] = session_secrets - self._encrypted_to_session[encrypted_value] = session_id - - logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value) - - return encrypted_value + logger.debug( + "Stored secret", service=service, type=secret_type, placeholder=uuid_placeholder + ) + return uuid_placeholder - def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[str]: + def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]: """Retrieve original value for an encrypted value""" - try: - stored_session_id = self._encrypted_to_session.get(encrypted_value) - if stored_session_id == session_id: - session_secrets = self._session_store[session_id].get(encrypted_value) - if session_secrets: - return session_secrets.original - except Exception as e: - logger.error("Error retrieving secret", error=str(e)) + secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder) + if secret_entry_json: + secret_entry = json.loads(secret_entry_json) + return secret_entry.get("original") return None - def get_by_session_id(self, session_id: str) -> Optional[SecretEntry]: - """Get stored data by session ID""" - return self._session_store.get(session_id) - - def cleanup(self): - """Securely wipe sensitive data""" - try: - # Convert and wipe original values - for secrets in self._session_store.values(): - for entry in secrets.values(): - original_bytes = bytearray(entry.original.encode()) - self.crypto.wipe_bytearray(original_bytes) - - # Clear the dictionaries - self._session_store.clear() - self._encrypted_to_session.clear() - - logger.info("Secrets manager data securely wiped") - except Exception as e: - logger.error("Error during secure cleanup", error=str(e)) - - def cleanup_session(self, session_id: str): - """ - Remove a specific session's secrets and perform secure cleanup. - - Args: - session_id (str): The session identifier to remove - """ - try: - # Get the secret entry for the session - secrets = self._session_store.get(session_id, {}) - - for entry in secrets.values(): - # Securely wipe the original value - original_bytes = bytearray(entry.original.encode()) - self.crypto.wipe_bytearray(original_bytes) - - # Remove the encrypted value from the reverse lookup index - self._encrypted_to_session.pop(entry.encrypted, None) - - # Remove the session from the store - self._session_store.pop(session_id, None) - - logger.debug("Session secrets securely removed", session_id=session_id) - else: - logger.debug("No secrets found for session", session_id=session_id) - except Exception as e: - logger.error("Error during session cleanup", session_id=session_id, error=str(e)) + def cleanup_session(self, session_id): + self.session_store.cleanup_session(session_id) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 184c3ba3..a56b71e9 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -179,15 +179,16 @@ def __init__( self._session_id = session_id self._context = context self._name = "codegate-secrets" + super().__init__() def _hide_secret(self, match: Match) -> str: # Encrypt and store the value encrypted_value = self._secrets_manager.store_secret( + self._session_id, match.value, match.service, match.type, - self._session_id, ) return f"REDACTED<${encrypted_value}>" @@ -428,7 +429,13 @@ async def process_chunk( encrypted_value = match.group(1) if encrypted_value.startswith("$"): encrypted_value = encrypted_value[1:] + + session_id = context.sensitive.session_id + if not session_id: + raise ValueError("Session ID not found in context") + original_value = input_context.sensitive.manager.get_original_value( + session_id, encrypted_value, input_context.sensitive.session_id, ) diff --git a/src/codegate/session/session_store.py b/src/codegate/session/session_store.py new file mode 100644 index 00000000..a493524e --- /dev/null +++ b/src/codegate/session/session_store.py @@ -0,0 +1,30 @@ +from typing import Dict, Optional +import uuid + + +class SessionStore: + """ + A generic session store for managing data protection. + """ + + def __init__(self): + self.sessions: Dict[str, Dict[str, str]] = {} + + def add_mapping(self, session_id: str, data: str) -> str: + uuid_placeholder = f"#{str(uuid.uuid4())}#" + if session_id not in self.sessions: + self.sessions[session_id] = {} + self.sessions[session_id][uuid_placeholder] = data + return uuid_placeholder + + def get_mapping(self, session_id: str, uuid_placeholder: str) -> Optional[str]: + return self.sessions.get(session_id, {}).get(uuid_placeholder) + + def cleanup_session(self, session_id: str): + """Clears all stored mappings for a specific session.""" + if session_id in self.sessions: + del self.sessions[session_id] + + def cleanup(self): + """Clears all stored mappings for all sessions.""" + self.sessions.clear() diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py index 8d5a7c6e..7d0fc55f 100644 --- a/tests/pipeline/pii/test_analyzer.py +++ b/tests/pipeline/pii/test_analyzer.py @@ -3,44 +3,8 @@ import pytest from presidio_analyzer import RecognizerResult -from codegate.pipeline.pii.analyzer import PiiAnalyzer, PiiSessionStore - - -class TestPiiSessionStore: - def test_init_with_session_id(self): - session_id = "test-session" - store = PiiSessionStore(session_id) - assert store.session_id == session_id - assert store.mappings == {} - - def test_init_without_session_id(self): - store = PiiSessionStore() - assert isinstance(store.session_id, str) - assert len(store.session_id) > 0 - assert store.mappings == {} - - def test_add_mapping(self): - store = PiiSessionStore() - pii = "test@example.com" - placeholder = store.add_mapping(pii) - - assert placeholder.startswith("<") - assert placeholder.endswith(">") - assert store.mappings[placeholder] == pii - - def test_get_pii_existing(self): - store = PiiSessionStore() - pii = "test@example.com" - placeholder = store.add_mapping(pii) - - result = store.get_pii(placeholder) - assert result == pii - - def test_get_pii_nonexistent(self): - store = PiiSessionStore() - placeholder = "" - result = store.get_pii(placeholder) - assert result == placeholder +from codegate.pipeline.pii.analyzer import PiiAnalyzer +from codegate.session.session_store import SessionStore class TestPiiAnalyzer: @@ -112,7 +76,7 @@ def test_analyze_no_pii(self, analyzer, mock_analyzer_engine): assert result_text == text assert found_pii == [] - assert isinstance(session_store, PiiSessionStore) + assert isinstance(session_store, SessionStore) def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): text = "My email is test@example.com" @@ -141,31 +105,32 @@ def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): assert session_store.get_pii(placeholder) == "test@example.com" def test_restore_pii(self, analyzer): - session_store = PiiSessionStore() + session_store = SessionStore() original_text = "test@example.com" - placeholder = session_store.add_mapping(original_text) + session_id = "session-id" + placeholder = session_store.add_mapping(session_id, original_text) anonymized_text = f"My email is {placeholder}" - restored_text = analyzer.restore_pii(anonymized_text, session_store) + restored_text = analyzer.restore_pii(anonymized_text, session_id) assert restored_text == f"My email is {original_text}" def test_restore_pii_multiple(self, analyzer): - session_store = PiiSessionStore() + session_store = SessionStore() email = "test@example.com" phone = "123-456-7890" - email_placeholder = session_store.add_mapping(email) - phone_placeholder = session_store.add_mapping(phone) + session_id = "session-id" + email_placeholder = session_store.add_mapping(session_id, email) + phone_placeholder = session_store.add_mapping(session_id, phone) anonymized_text = f"Email: {email_placeholder}, Phone: {phone_placeholder}" - restored_text = analyzer.restore_pii(anonymized_text, session_store) + restored_text = analyzer.restore_pii(anonymized_text, session_id) assert restored_text == f"Email: {email}, Phone: {phone}" def test_restore_pii_no_placeholders(self, analyzer): - session_store = PiiSessionStore() text = "No PII here" - - restored_text = analyzer.restore_pii(text, session_store) + session_id = "session-id" + restored_text = analyzer.restore_pii(text, session_id) assert restored_text == text diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py index 177e8f3f..d5c2d06d 100644 --- a/tests/pipeline/secrets/test_manager.py +++ b/tests/pipeline/secrets/test_manager.py @@ -7,7 +7,6 @@ class TestSecretsManager: def setup_method(self): """Setup a fresh SecretsManager for each test""" self.manager = SecretsManager() - self.test_session = "test_session_id" self.test_value = "super_secret_value" self.test_service = "test_service" self.test_type = "api_key" @@ -15,9 +14,7 @@ def setup_method(self): def test_store_secret(self): """Test basic secret storage and retrieval""" # Store a secret - encrypted = self.manager.store_secret( - self.test_value, self.test_service, self.test_type, self.test_session - ) + encrypted = self.manager.store_secret(self.test_value, self.test_service, self.test_type) # Verify the secret was stored stored = self.manager.get_by_session_id(self.test_session) @@ -30,9 +27,7 @@ def test_store_secret(self): def test_get_original_value_wrong_session(self): """Test that secrets can't be accessed with wrong session ID""" - encrypted = self.manager.store_secret( - self.test_value, self.test_service, self.test_type, self.test_session - ) + encrypted = self.manager.store_secret(self.test_value, self.test_service, self.test_type) # Try to retrieve with wrong session ID wrong_session = "wrong_session_id" From a615e0345210ca9c0b9a65704cfbe5513aed961d Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Thu, 20 Feb 2025 15:05:43 +0100 Subject: [PATCH 2/8] fix tests --- src/codegate/pipeline/pii/analyzer.py | 6 +- src/codegate/pipeline/pii/manager.py | 16 ++- src/codegate/pipeline/pii/pii.py | 20 +-- src/codegate/pipeline/secrets/manager.py | 18 ++- src/codegate/pipeline/secrets/secrets.py | 3 +- src/codegate/session/session_store.py | 3 + tests/pipeline/pii/test_analyzer.py | 25 ++-- tests/pipeline/pii/test_pi.py | 10 +- tests/pipeline/pii/test_pii_manager.py | 56 ++++---- tests/pipeline/secrets/test_gatecrypto.py | 157 ---------------------- tests/pipeline/secrets/test_manager.py | 70 +++++----- tests/pipeline/secrets/test_secrets.py | 4 +- 12 files changed, 127 insertions(+), 261 deletions(-) delete mode 100644 tests/pipeline/secrets/test_gatecrypto.py diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py index 75d5f299..041cbc01 100644 --- a/src/codegate/pipeline/pii/analyzer.py +++ b/src/codegate/pipeline/pii/analyzer.py @@ -70,7 +70,7 @@ def __init__(self): PiiAnalyzer._instance = self def analyze( - self, text: str, session_id: str, context: Optional[PipelineContext] = None + self, session_id: str, text: str, context: Optional[PipelineContext] = None ) -> Tuple[str, List[Dict[str, Any]]]: # Prioritize credit card detection first entities = [ @@ -155,7 +155,7 @@ def analyze( # If no PII found, return original text, empty list, and session store return text, [] - def restore_pii(self, anonymized_text: str, session_id: str) -> str: + def restore_pii(self, session_id: str, anonymized_text: str) -> str: """ Restore the original PII (Personally Identifiable Information) in the given anonymized text. @@ -164,7 +164,7 @@ def restore_pii(self, anonymized_text: str, session_id: str) -> str: Args: anonymized_text (str): The text containing placeholders for PII. - session_store (SessionStore): The session store containing mappings of placeholders + session_id (str): The session id containing mappings of placeholders to original PII. Returns: diff --git a/src/codegate/pipeline/pii/manager.py b/src/codegate/pipeline/pii/manager.py index 265bb8ce..0b847bee 100644 --- a/src/codegate/pipeline/pii/manager.py +++ b/src/codegate/pipeline/pii/manager.py @@ -24,17 +24,19 @@ class PiiManager: Initializes the PiiManager with the singleton PiiAnalyzer instance and sets the session store. - analyze(text: str, session_id: str) -> Tuple[str, List[Dict[str, Any]]]: + analyze(session_id: str, text: str) -> Tuple[str, List[Dict[str, Any]]]: Analyzes the given text for PII, anonymizes it, and logs the detected PII details. Args: + session_id (str): The session id to store the PII. text (str): The text to be analyzed for PII. Returns: Tuple[str, List[Dict[str, Any]]]: A tuple containing the anonymized text and a list of found PII details. - restore_pii(anonymized_text: str, session_id: str ) -> str: + restore_pii(session_id: str, anonymized_text: st ) -> str: Restores the PII in the given anonymized text using the current session. Args: + session_id (str): The session id for the PII to be restored. anonymized_text (str): The text with anonymized PII to be restored. Returns: str: The text with restored PII. @@ -55,10 +57,10 @@ def session_store(self) -> SessionStore: return self.analyzer.session_store def analyze( - self, text: str, session_id: str, context: Optional[PipelineContext] = None + self, session_id: str, text: str, context: Optional[PipelineContext] = None ) -> Tuple[str, List[Dict[str, Any]]]: # Call analyzer and get results - anonymized_text, found_pii = self.analyzer.analyze(text, session_id, context=context) + anonymized_text, found_pii = self.analyzer.analyze(session_id, text, context=context) # Log found PII details (without modifying the found_pii list) if found_pii: @@ -73,9 +75,11 @@ def analyze( # Return the exact same objects we got from the analyzer return anonymized_text, found_pii - def restore_pii(self, anonymized_text: str, session_id: str) -> str: + def restore_pii(self, session_id: str, anonymized_text: str) -> str: """ Restore PII in the given anonymized text using the current session. """ + if not session_id: + return anonymized_text # Use the analyzer's restore_pii method with the current session store - return self.analyzer.restore_pii(anonymized_text, session_id) + return self.analyzer.restore_pii(session_id, anonymized_text) diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index b19ff5f3..0f7b0e55 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -38,7 +38,7 @@ class CodegatePii(PipelineStep): Processes the chat completion request to detect and redact PII. Updates the request with anonymized text and stores PII details in the context metadata. - restore_pii(anonymized_text: str, session_id: str) -> str: + restore_pii(session_id: str, anonymized_text: str) -> str: Restores the original PII from the anonymized text using the PiiManager. """ @@ -83,7 +83,7 @@ async def process( # This is where analyze and anonymize the text original_text = str(message["content"]) anonymized_text, pii_details = self.pii_manager.analyze( - original_text, session_id, context + session_id, original_text, context ) if pii_details: @@ -118,8 +118,8 @@ async def process( return PipelineResult(request=new_request, context=context) - def restore_pii(self, anonymized_text: str, session_id: str) -> str: - return self.pii_manager.restore_pii(anonymized_text, session_id) + def restore_pii(self, session_id: str, anonymized_text: str) -> str: + return self.pii_manager.restore_pii(session_id, anonymized_text) class PiiUnRedactionStep(OutputPipelineStep): @@ -141,12 +141,12 @@ class PiiUnRedactionStep(OutputPipelineStep): """ def __init__(self): - self.redacted_pattern = re.compile(r"<([0-9a-f-]{0,36})>") + self.redacted_pattern = re.compile(r"#([0-9a-f-]{0,36})#") self.complete_uuid_pattern = re.compile( r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" ) # noqa: E501 - self.marker_start = "<" - self.marker_end = ">" + self.marker_start = "#" + self.marker_end = "#" @property def name(self) -> str: @@ -181,13 +181,13 @@ async def process_chunk( # noqa: C901 current_pos = 0 result = [] while current_pos < len(content): - start_idx = content.find("<", current_pos) + start_idx = content.find("#", current_pos) if start_idx == -1: # No more markers!, add remaining content result.append(content[current_pos:]) break - end_idx = content.find(">", start_idx) + end_idx = content.find("#", start_idx + 1) if end_idx == -1: # Incomplete marker, buffer the rest context.prefix_buffer = content[current_pos:] @@ -199,7 +199,7 @@ async def process_chunk( # noqa: C901 # Extract potential UUID if it's a valid format! uuid_marker = content[start_idx : end_idx + 1] - uuid_value = uuid_marker[1:-1] # Remove < > + uuid_value = uuid_marker[1:-1] # Remove # # if self._is_complete_uuid(uuid_value): # Get the PII manager from context metadata diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py index 4b852951..9476eef9 100644 --- a/src/codegate/pipeline/secrets/manager.py +++ b/src/codegate/pipeline/secrets/manager.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Dict, Optional import structlog @@ -19,6 +19,9 @@ def store_secret(self, session_id: str, value: str, service: str, secret_type: s Encrypts and stores a secret value. Returns the encrypted value. """ + if not session_id: + raise ValueError("Session id must be provided") + if not value: raise ValueError("Value must be provided") if not service: @@ -35,6 +38,16 @@ def store_secret(self, session_id: str, value: str, service: str, secret_type: s ) return uuid_placeholder + def get_by_session_id(self, session_id: str) -> Optional[Dict]: + session_data = self.session_store.get_by_session_id(session_id) + if not session_data: + return None + # Convert all string values to dictionary objects using json.loads + return { + key: json.loads(value) if isinstance(value, str) else value + for key, value in session_data.items() + } + def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]: """Retrieve original value for an encrypted value""" secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder) @@ -45,3 +58,6 @@ def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional def cleanup_session(self, session_id): self.session_store.cleanup_session(session_id) + + def cleanup(self): + self.session_store.cleanup() diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index a56b71e9..68973a6f 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -430,14 +430,13 @@ async def process_chunk( if encrypted_value.startswith("$"): encrypted_value = encrypted_value[1:] - session_id = context.sensitive.session_id + session_id = input_context.sensitive.session_id if not session_id: raise ValueError("Session ID not found in context") original_value = input_context.sensitive.manager.get_original_value( session_id, encrypted_value, - input_context.sensitive.session_id, ) if original_value is None: diff --git a/src/codegate/session/session_store.py b/src/codegate/session/session_store.py index a493524e..5e508847 100644 --- a/src/codegate/session/session_store.py +++ b/src/codegate/session/session_store.py @@ -17,6 +17,9 @@ def add_mapping(self, session_id: str, data: str) -> str: self.sessions[session_id][uuid_placeholder] = data return uuid_placeholder + def get_by_session_id(self, session_id: str) -> Optional[Dict]: + return self.sessions.get(session_id, None) + def get_mapping(self, session_id: str, uuid_placeholder: str) -> Optional[str]: return self.sessions.get(session_id, {}).get(uuid_placeholder) diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py index 7d0fc55f..618549df 100644 --- a/tests/pipeline/pii/test_analyzer.py +++ b/tests/pipeline/pii/test_analyzer.py @@ -70,16 +70,17 @@ def test_singleton_pattern(self): def test_analyze_no_pii(self, analyzer, mock_analyzer_engine): text = "Hello world" + session_id = "session-id" mock_analyzer_engine.analyze.return_value = [] - result_text, found_pii, session_store = analyzer.analyze(text) + result_text, found_pii = analyzer.analyze(session_id, text) assert result_text == text assert found_pii == [] - assert isinstance(session_store, SessionStore) def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): text = "My email is test@example.com" + session_id = "session-id" email_pii = RecognizerResult( entity_type="EMAIL_ADDRESS", start=12, @@ -88,7 +89,7 @@ def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): ) mock_analyzer_engine.analyze.return_value = [email_pii] - result_text, found_pii, session_store = analyzer.analyze(text) + result_text, found_pii = analyzer.analyze(session_id, text) assert len(found_pii) == 1 pii_info = found_pii[0] @@ -101,36 +102,32 @@ def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): # Verify the placeholder was used to replace the PII placeholder = pii_info["uuid_placeholder"] assert result_text == f"My email is {placeholder}" - # Verify the mapping was stored - assert session_store.get_pii(placeholder) == "test@example.com" def test_restore_pii(self, analyzer): - session_store = SessionStore() original_text = "test@example.com" session_id = "session-id" - placeholder = session_store.add_mapping(session_id, original_text) - anonymized_text = f"My email is {placeholder}" - restored_text = analyzer.restore_pii(anonymized_text, session_id) + placeholder = analyzer.session_store.add_mapping(session_id, original_text) + anonymized_text = f"My email is {placeholder}" + restored_text = analyzer.restore_pii(session_id, anonymized_text) assert restored_text == f"My email is {original_text}" def test_restore_pii_multiple(self, analyzer): - session_store = SessionStore() email = "test@example.com" phone = "123-456-7890" session_id = "session-id" - email_placeholder = session_store.add_mapping(session_id, email) - phone_placeholder = session_store.add_mapping(session_id, phone) + email_placeholder = analyzer.session_store.add_mapping(session_id, email) + phone_placeholder = analyzer.session_store.add_mapping(session_id, phone) anonymized_text = f"Email: {email_placeholder}, Phone: {phone_placeholder}" - restored_text = analyzer.restore_pii(anonymized_text, session_id) + restored_text = analyzer.restore_pii(session_id, anonymized_text) assert restored_text == f"Email: {email}, Phone: {phone}" def test_restore_pii_no_placeholders(self, analyzer): text = "No PII here" session_id = "session-id" - restored_text = analyzer.restore_pii(text, session_id) + restored_text = analyzer.restore_pii(session_id, text) assert restored_text == text diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index 6578a7b6..8fa52acf 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -96,9 +96,10 @@ async def test_process_with_pii(self, pii_step): def test_restore_pii(self, pii_step): anonymized_text = "My email is " original_text = "My email is test@example.com" + session_id = "session-id" pii_step.pii_manager.restore_pii = MagicMock(return_value=original_text) - restored = pii_step.restore_pii(anonymized_text) + restored = pii_step.restore_pii(session_id, anonymized_text) assert restored == original_text @@ -148,7 +149,7 @@ async def test_process_chunk_with_uuid(self, unredaction_step): StreamingChoices( finish_reason=None, index=0, - delta=Delta(content=f"Text with <{uuid}>"), + delta=Delta(content=f"Text with #{uuid}#"), logprobs=None, ) ], @@ -157,17 +158,16 @@ async def test_process_chunk_with_uuid(self, unredaction_step): object="chat.completion.chunk", ) context = OutputPipelineContext() - input_context = PipelineContext() + input_context = PipelineContext(metadata={"session_id": "session-id"}) # Mock PII manager in input context mock_pii_manager = MagicMock() mock_session = MagicMock() - mock_session.get_pii = MagicMock(return_value="test@example.com") + mock_session.get_mapping = MagicMock(return_value="test@example.com") mock_pii_manager.session_store = mock_session input_context.metadata["pii_manager"] = mock_pii_manager result = await unredaction_step.process_chunk(chunk, context, input_context) - assert result[0].choices[0].delta.content == "Text with test@example.com" diff --git a/tests/pipeline/pii/test_pii_manager.py b/tests/pipeline/pii/test_pii_manager.py index 229b7314..aa363240 100644 --- a/tests/pipeline/pii/test_pii_manager.py +++ b/tests/pipeline/pii/test_pii_manager.py @@ -2,7 +2,7 @@ import pytest -from codegate.pipeline.pii.analyzer import PiiSessionStore +from codegate.pipeline.pii.analyzer import SessionStore from codegate.pipeline.pii.manager import PiiManager @@ -10,7 +10,7 @@ class TestPiiManager: @pytest.fixture def session_store(self): """Create a session store that will be shared between the mock and manager""" - return PiiSessionStore() + return SessionStore() @pytest.fixture def mock_analyzer(self, session_store): @@ -36,18 +36,20 @@ def test_init(self, manager, mock_analyzer): def test_analyze_no_pii(self, manager, mock_analyzer): text = "Hello CodeGate" + session_id = "session-id" session_store = mock_analyzer.session_store - mock_analyzer.analyze.return_value = (text, [], session_store) + mock_analyzer.analyze.return_value = (text, []) - anonymized_text, found_pii = manager.analyze(text) + anonymized_text, found_pii = manager.analyze(session_id, text) assert anonymized_text == text assert found_pii == [] assert manager.session_store is session_store - mock_analyzer.analyze.assert_called_once_with(text, context=None) + mock_analyzer.analyze.assert_called_once_with(session_id, text, context=None) def test_analyze_with_pii(self, manager, mock_analyzer): text = "My email is test@example.com" + session_id = "session-id" session_store = mock_analyzer.session_store placeholder = "" pii_details = [ @@ -61,46 +63,52 @@ def test_analyze_with_pii(self, manager, mock_analyzer): } ] anonymized_text = f"My email is {placeholder}" - session_store.mappings[placeholder] = "test@example.com" - mock_analyzer.analyze.return_value = (anonymized_text, pii_details, session_store) + mock_analyzer.analyze.return_value = (anonymized_text, pii_details) + session_store.sessions[session_id] = {placeholder: "test@example.com"} - result_text, found_pii = manager.analyze(text) + result_text, found_pii = manager.analyze(session_id, text) assert "My email is <" in result_text assert ">" in result_text assert found_pii == pii_details assert manager.session_store is session_store - assert manager.session_store.mappings[placeholder] == "test@example.com" - mock_analyzer.analyze.assert_called_once_with(text, context=None) - def test_restore_pii_no_session(self, manager, mock_analyzer): - text = "Anonymized text" - # Create a new session store that's None - mock_analyzer.session_store = None + assert manager.session_store.sessions[session_id][placeholder] == "test@example.com" + mock_analyzer.analyze.assert_called_once_with(session_id, text, context=None) - restored_text = manager.restore_pii(text) + def test_restore_pii_no_session(self, manager): + text = "Anonymized text" + session_id = "" + restored_text = manager.restore_pii(session_id, text) assert restored_text == text def test_restore_pii_with_session(self, manager, mock_analyzer): - anonymized_text = "My email is " + anonymized_text = "My email is #test-uuid#" original_text = "My email is test@example.com" - manager.session_store.mappings[""] = "test@example.com" + session_id = "session-id" + session_store = mock_analyzer.session_store + session_store.sessions[session_id] = {"#test-uuid#": "test@example.com"} mock_analyzer.restore_pii.return_value = original_text - restored_text = manager.restore_pii(anonymized_text) + restored_text = manager.restore_pii(session_id, anonymized_text) assert restored_text == original_text - mock_analyzer.restore_pii.assert_called_once_with(anonymized_text, manager.session_store) + mock_analyzer.restore_pii.assert_called_once_with(session_id, anonymized_text) def test_restore_pii_multiple_placeholders(self, manager, mock_analyzer): - anonymized_text = "Email: , Phone: " + anonymized_text = "Email: #uuid1#, Phone: #uuid2#" original_text = "Email: test@example.com, Phone: 123-456-7890" - manager.session_store.mappings[""] = "test@example.com" - manager.session_store.mappings[""] = "123-456-7890" + session_id = "session-id" + session_store = mock_analyzer.session_store + session_store.sessions[session_id] = { + "#uuid1#": "test@example.com", + "#uuid2#": "123-456-7890", + } + mock_analyzer.restore_pii.return_value = original_text - restored_text = manager.restore_pii(anonymized_text) + restored_text = manager.restore_pii(session_id, anonymized_text) assert restored_text == original_text - mock_analyzer.restore_pii.assert_called_once_with(anonymized_text, manager.session_store) + mock_analyzer.restore_pii.assert_called_once_with(session_id, anonymized_text) diff --git a/tests/pipeline/secrets/test_gatecrypto.py b/tests/pipeline/secrets/test_gatecrypto.py deleted file mode 100644 index b7de4b19..00000000 --- a/tests/pipeline/secrets/test_gatecrypto.py +++ /dev/null @@ -1,157 +0,0 @@ -import time - -import pytest - -from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto - - -@pytest.fixture -def crypto(): - return CodeGateCrypto() - - -def test_generate_session_key(crypto): - session_id = "test_session" - key = crypto.generate_session_key(session_id) - - assert len(key) == 32 # AES-256 key size - assert session_id in crypto.session_keys - assert isinstance(crypto.session_keys[session_id], tuple) - assert len(crypto.session_keys[session_id]) == 2 - - -def test_get_session_key(crypto): - session_id = "test_session" - original_key = crypto.generate_session_key(session_id) - retrieved_key = crypto.get_session_key(session_id) - - assert original_key == retrieved_key - - -def test_get_expired_session_key(crypto): - session_id = "test_session" - crypto.generate_session_key(session_id) - - # Manually expire the key by modifying its timestamp - key, _ = crypto.session_keys[session_id] - crypto.session_keys[session_id] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - - retrieved_key = crypto.get_session_key(session_id) - assert retrieved_key is None - assert session_id not in crypto.session_keys - - -def test_cleanup_expired_keys(crypto): - # Generate multiple session keys - session_ids = ["session1", "session2", "session3"] - for session_id in session_ids: - crypto.generate_session_key(session_id) - - # Manually expire some keys - key, _ = crypto.session_keys["session1"] - crypto.session_keys["session1"] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - key, _ = crypto.session_keys["session2"] - crypto.session_keys["session2"] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - - crypto.cleanup_expired_keys() - - assert "session1" not in crypto.session_keys - assert "session2" not in crypto.session_keys - assert "session3" in crypto.session_keys - - -def test_encrypt_decrypt_token(crypto): - session_id = "test_session" - original_token = "sensitive_data_123" - - encrypted_token = crypto.encrypt_token(original_token, session_id) - decrypted_token = crypto.decrypt_token(encrypted_token, session_id) - - assert decrypted_token == original_token - - -def test_decrypt_with_expired_session(crypto): - session_id = "test_session" - token = "sensitive_data_123" - - encrypted_token = crypto.encrypt_token(token, session_id) - - # Manually expire the session key - key, _ = crypto.session_keys[session_id] - crypto.session_keys[session_id] = (key, time.time() - (crypto.SESSION_KEY_LIFETIME + 10)) - - with pytest.raises(ValueError, match="Session key expired or invalid."): - crypto.decrypt_token(encrypted_token, session_id) - - -def test_decrypt_with_invalid_session(crypto): - session_id = "test_session" - token = "sensitive_data_123" - - encrypted_token = crypto.encrypt_token(token, session_id) - - with pytest.raises(ValueError, match="Session key expired or invalid."): - crypto.decrypt_token(encrypted_token, "invalid_session") - - -def test_decrypt_with_expired_token(crypto, monkeypatch): - session_id = "test_session" - token = "sensitive_data_123" - current_time = time.time() - - # Mock time.time() for token encryption - monkeypatch.setattr(time, "time", lambda: current_time) - - # Generate token with current timestamp - encrypted_token = crypto.encrypt_token(token, session_id) - - # Mock time.time() to return a future timestamp for decryption - future_time = current_time + crypto.SESSION_KEY_LIFETIME + 10 - monkeypatch.setattr(time, "time", lambda: future_time) - - # Keep the original key but update its timestamp to keep it valid - key, _ = crypto.session_keys[session_id] - crypto.session_keys[session_id] = (key, future_time) - - with pytest.raises(ValueError, match="Token has expired."): - crypto.decrypt_token(encrypted_token, session_id) - - -def test_wipe_bytearray(crypto): - # Create a bytearray with sensitive data - sensitive_data = bytearray(b"sensitive_information") - original_content = sensitive_data.copy() - - # Wipe the data - crypto.wipe_bytearray(sensitive_data) - - # Verify all bytes are zeroed - assert all(byte == 0 for byte in sensitive_data) - assert sensitive_data != original_content - - -def test_wipe_bytearray_invalid_input(crypto): - # Try to wipe a string instead of bytearray - with pytest.raises(ValueError, match="Only bytearray objects can be securely wiped."): - crypto.wipe_bytearray("not a bytearray") - - -def test_encrypt_decrypt_with_special_characters(crypto): - session_id = "test_session" - special_chars_token = "!@#$%^&*()_+-=[]{}|;:,.<>?" - - encrypted_token = crypto.encrypt_token(special_chars_token, session_id) - decrypted_token = crypto.decrypt_token(encrypted_token, session_id) - - assert decrypted_token == special_chars_token - - -def test_encrypt_decrypt_multiple_tokens(crypto): - session_id = "test_session" - tokens = ["token1", "token2", "token3"] - - # Encrypt and immediately decrypt each token - for token in tokens: - encrypted = crypto.encrypt_token(token, session_id) - decrypted = crypto.decrypt_token(encrypted, session_id) - assert decrypted == token diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py index d5c2d06d..d7d97787 100644 --- a/tests/pipeline/secrets/test_manager.py +++ b/tests/pipeline/secrets/test_manager.py @@ -7,6 +7,7 @@ class TestSecretsManager: def setup_method(self): """Setup a fresh SecretsManager for each test""" self.manager = SecretsManager() + self.test_session = "session-id" self.test_value = "super_secret_value" self.test_service = "test_service" self.test_type = "api_key" @@ -14,24 +15,27 @@ def setup_method(self): def test_store_secret(self): """Test basic secret storage and retrieval""" # Store a secret - encrypted = self.manager.store_secret(self.test_value, self.test_service, self.test_type) + encrypted = self.manager.store_secret( + self.test_session, self.test_value, self.test_service, self.test_type + ) # Verify the secret was stored stored = self.manager.get_by_session_id(self.test_session) - assert isinstance(stored, dict) - assert stored[encrypted].original == self.test_value + assert stored[encrypted]["original"] == self.test_value # Verify encrypted value can be retrieved - retrieved = self.manager.get_original_value(encrypted, self.test_session) + retrieved = self.manager.get_original_value(self.test_session, encrypted) assert retrieved == self.test_value def test_get_original_value_wrong_session(self): """Test that secrets can't be accessed with wrong session ID""" - encrypted = self.manager.store_secret(self.test_value, self.test_service, self.test_type) + encrypted = self.manager.store_secret( + self.test_session, self.test_value, self.test_service, self.test_type + ) # Try to retrieve with wrong session ID wrong_session = "wrong_session_id" - retrieved = self.manager.get_original_value(encrypted, wrong_session) + retrieved = self.manager.get_original_value(wrong_session, encrypted) assert retrieved is None def test_get_original_value_nonexistent(self): @@ -45,19 +49,19 @@ def test_cleanup_session(self): session1 = "session1" session2 = "session2" - encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) - encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) + encrypted1 = self.manager.store_secret(session1, "secret1", "service1", "type1") + encrypted2 = self.manager.store_secret(session2, "secret2", "service2", "type2") # Clean up session1 self.manager.cleanup_session(session1) # Verify session1 secrets are gone assert self.manager.get_by_session_id(session1) is None - assert self.manager.get_original_value(encrypted1, session1) is None + assert self.manager.get_original_value(session1, encrypted1) is None # Verify session2 secrets remain assert self.manager.get_by_session_id(session2) is not None - assert self.manager.get_original_value(encrypted2, session2) == "secret2" + assert self.manager.get_original_value(session2, encrypted2) == "secret2" def test_cleanup(self): """Test that cleanup properly wipes all data""" @@ -69,49 +73,44 @@ def test_cleanup(self): self.manager.cleanup() # Verify all data is wiped - assert len(self.manager._session_store) == 0 - assert len(self.manager._encrypted_to_session) == 0 + assert len(self.manager.session_store.sessions) == 0 def test_multiple_secrets_same_session(self): """Test storing multiple secrets in the same session""" # Store multiple secrets in same session - encrypted1 = self.manager.store_secret("secret1", "service1", "type1", self.test_session) - encrypted2 = self.manager.store_secret("secret2", "service2", "type2", self.test_session) + encrypted1 = self.manager.store_secret(self.test_session, "secret1", "service1", "type1") + encrypted2 = self.manager.store_secret(self.test_session, "secret2", "service2", "type2") # Latest secret should be retrievable in the session stored = self.manager.get_by_session_id(self.test_session) assert isinstance(stored, dict) - assert stored[encrypted1].original == "secret1" - assert stored[encrypted2].original == "secret2" + assert stored[encrypted1]["original"] == "secret1" + assert stored[encrypted2]["original"] == "secret2" # Both secrets should be retrievable directly - assert self.manager.get_original_value(encrypted1, self.test_session) == "secret1" - assert self.manager.get_original_value(encrypted2, self.test_session) == "secret2" - - # Both encrypted values should map to the session - assert self.manager._encrypted_to_session[encrypted1] == self.test_session - assert self.manager._encrypted_to_session[encrypted2] == self.test_session + assert self.manager.get_original_value(self.test_session, encrypted1) == "secret1" + assert self.manager.get_original_value(self.test_session, encrypted2) == "secret2" def test_error_handling(self): """Test error handling in secret operations""" # Test with None values with pytest.raises(ValueError): - self.manager.store_secret(None, self.test_service, self.test_type, self.test_session) + self.manager.store_secret(self.test_session, None, self.test_service, self.test_type) with pytest.raises(ValueError): - self.manager.store_secret(self.test_value, None, self.test_type, self.test_session) + self.manager.store_secret(self.test_session, self.test_value, None, self.test_type) with pytest.raises(ValueError): - self.manager.store_secret(self.test_value, self.test_service, None, self.test_session) + self.manager.store_secret(self.test_session, self.test_value, self.test_service, None) with pytest.raises(ValueError): - self.manager.store_secret(self.test_value, self.test_service, self.test_type, None) + self.manager.store_secret(None, self.test_value, self.test_service, self.test_type) def test_secure_cleanup(self): """Test that cleanup securely wipes sensitive data""" # Store a secret self.manager.store_secret( - self.test_value, self.test_service, self.test_type, self.test_session + self.test_session, self.test_value, self.test_service, self.test_type ) # Get reference to stored data before cleanup @@ -121,10 +120,7 @@ def test_secure_cleanup(self): # Perform cleanup self.manager.cleanup() - # Verify the original string was overwritten, not just removed - # This test is a bit tricky since Python strings are immutable, - # but we can at least verify the data is no longer accessible - assert self.test_value not in str(self.manager._session_store) + assert len(self.manager.session_store.sessions) == 0 def test_session_isolation(self): """Test that sessions are properly isolated""" @@ -132,13 +128,13 @@ def test_session_isolation(self): session2 = "session2" # Store secrets in different sessions - encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) - encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) + encrypted1 = self.manager.store_secret(session1, "secret1", "service1", "type1") + encrypted2 = self.manager.store_secret(session2, "secret2", "service2", "type2") # Verify cross-session access is not possible - assert self.manager.get_original_value(encrypted1, session2) is None - assert self.manager.get_original_value(encrypted2, session1) is None + assert self.manager.get_original_value(session2, encrypted1) is None + assert self.manager.get_original_value(session1, encrypted2) is None # Verify correct session access works - assert self.manager.get_original_value(encrypted1, session1) == "secret1" - assert self.manager.get_original_value(encrypted2, session2) == "secret2" + assert self.manager.get_original_value(session1, encrypted1) == "secret1" + assert self.manager.get_original_value(session2, encrypted2) == "secret2" diff --git a/tests/pipeline/secrets/test_secrets.py b/tests/pipeline/secrets/test_secrets.py index 759b94b0..e5995a93 100644 --- a/tests/pipeline/secrets/test_secrets.py +++ b/tests/pipeline/secrets/test_secrets.py @@ -92,7 +92,7 @@ def test_hide_secret(self): # Verify the secret was stored encrypted_value = hidden[len("REDACTED<$") : -1] - original = self.secrets_manager.get_original_value(encrypted_value, self.session_id) + original = self.secrets_manager.get_original_value(self.session_id, encrypted_value) assert original == "AKIAIOSFODNN7EXAMPLE" def test_obfuscate(self): @@ -185,7 +185,7 @@ async def test_complete_marker_processing(self): """Test processing of a complete REDACTED marker""" # Store a secret encrypted = self.secrets_manager.store_secret( - "secret_value", "test_service", "api_key", self.session_id + self.session_id, "secret_value", "test_service", "api_key" ) # Add content with REDACTED marker to buffer From 444a6b1ffedb23d1c27e1fca04d8843ddd5ded88 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Fri, 21 Feb 2025 11:21:34 +0100 Subject: [PATCH 3/8] unify interface in sensitive data --- src/codegate/cli.py | 6 +- src/codegate/pipeline/base.py | 16 +- src/codegate/pipeline/factory.py | 10 +- src/codegate/pipeline/pii/analyzer.py | 69 +-------- src/codegate/pipeline/pii/manager.py | 85 ----------- src/codegate/pipeline/pii/pii.py | 141 +++++++++++++++--- src/codegate/pipeline/secrets/manager.py | 63 -------- src/codegate/pipeline/secrets/secrets.py | 49 +++--- .../pipeline/sensitive_data/manager.py | 56 +++++++ .../sensitive_data}/session_store.py | 0 src/codegate/providers/copilot/provider.py | 4 +- tests/pipeline/pii/test_analyzer.py | 36 ----- tests/pipeline/pii/test_pi.py | 60 +------- tests/pipeline/pii/test_pii_manager.py | 114 -------------- tests/pipeline/secrets/test_manager.py | 140 ----------------- tests/pipeline/secrets/test_secrets.py | 42 +++--- tests/test_server.py | 12 +- 17 files changed, 252 insertions(+), 651 deletions(-) delete mode 100644 src/codegate/pipeline/pii/manager.py delete mode 100644 src/codegate/pipeline/secrets/manager.py create mode 100644 src/codegate/pipeline/sensitive_data/manager.py rename src/codegate/{session => pipeline/sensitive_data}/session_store.py (100%) delete mode 100644 tests/pipeline/pii/test_pii_manager.py delete mode 100644 tests/pipeline/secrets/test_manager.py diff --git a/src/codegate/cli.py b/src/codegate/cli.py index be5096f6..455d9001 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -16,7 +16,7 @@ from codegate.config import Config, ConfigurationError from codegate.db.connection import init_db_sync, init_session_if_not_exists from codegate.pipeline.factory import PipelineFactory -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers import crud as provendcrud from codegate.providers.copilot.provider import CopilotProvider from codegate.server import init_app @@ -331,8 +331,8 @@ def serve( # noqa: C901 click.echo("Existing Certificates are already present.") # Initialize secrets manager and pipeline factory - secrets_manager = SecretsManager() - pipeline_factory = PipelineFactory(secrets_manager) + sensitive_data_manager = SensitiveDataManager() + pipeline_factory = PipelineFactory(sensitive_data_manager) app = init_app(pipeline_factory) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index f9ce39b6..29c60b62 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -12,14 +12,14 @@ from codegate.clients.clients import ClientType from codegate.db.models import Alert, AlertSeverity, Output, Prompt from codegate.extract_snippets.message_extractor import CodeSnippet -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager logger = structlog.get_logger("codegate") @dataclass class PipelineSensitiveData: - manager: SecretsManager + manager: SensitiveDataManager session_id: str api_key: Optional[str] = None model: Optional[str] = None @@ -266,19 +266,19 @@ class InputPipelineInstance: def __init__( self, pipeline_steps: List[PipelineStep], - secret_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, is_fim: bool, client: ClientType = ClientType.GENERIC, ): self.pipeline_steps = pipeline_steps - self.secret_manager = secret_manager + self.sensitive_data_manager = sensitive_data_manager self.is_fim = is_fim self.context = PipelineContext(client=client) # we create the sesitive context here so that it is not shared between individual requests # TODO: could we get away with just generating the session ID for an instance? self.context.sensitive = PipelineSensitiveData( - manager=self.secret_manager, + manager=self.sensitive_data_manager, session_id=str(uuid.uuid4()), ) self.context.metadata["is_fim"] = is_fim @@ -335,12 +335,12 @@ class SequentialPipelineProcessor: def __init__( self, pipeline_steps: List[PipelineStep], - secret_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, client_type: ClientType, is_fim: bool, ): self.pipeline_steps = pipeline_steps - self.secret_manager = secret_manager + self.sensitive_data_manager = sensitive_data_manager self.is_fim = is_fim self.instance = self._create_instance(client_type) @@ -348,7 +348,7 @@ def _create_instance(self, client_type: ClientType) -> InputPipelineInstance: """Create a new pipeline instance for processing a request""" return InputPipelineInstance( self.pipeline_steps, - self.secret_manager, + self.sensitive_data_manager, self.is_fim, client_type, ) diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index acde51b4..3a2e3479 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -12,18 +12,18 @@ PiiRedactionNotifier, PiiUnRedactionStep, ) -from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.secrets import ( CodegateSecrets, SecretRedactionNotifier, SecretUnredactionStep, ) +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.pipeline.system_prompt.codegate import SystemPrompt class PipelineFactory: - def __init__(self, secrets_manager: SecretsManager): - self.secrets_manager = secrets_manager + def __init__(self, sensitive_data_manager: SensitiveDataManager): + self.sensitive_data_manager = sensitive_data_manager def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor: input_steps: List[PipelineStep] = [ @@ -41,7 +41,7 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr ] return SequentialPipelineProcessor( input_steps, - self.secrets_manager, + self.sensitive_data_manager, client_type, is_fim=False, ) @@ -53,7 +53,7 @@ def create_fim_pipeline(self, client_type: ClientType) -> SequentialPipelineProc ] return SequentialPipelineProcessor( fim_steps, - self.secrets_manager, + self.sensitive_data_manager, client_type, is_fim=True, ) diff --git a/src/codegate/pipeline/pii/analyzer.py b/src/codegate/pipeline/pii/analyzer.py index 041cbc01..96442824 100644 --- a/src/codegate/pipeline/pii/analyzer.py +++ b/src/codegate/pipeline/pii/analyzer.py @@ -1,5 +1,4 @@ -import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional import structlog from presidio_analyzer import AnalyzerEngine @@ -7,7 +6,7 @@ from codegate.db.models import AlertSeverity from codegate.pipeline.base import PipelineContext -from codegate.session.session_store import SessionStore +from codegate.pipeline.sensitive_data.session_store import SessionStore logger = structlog.get_logger("codegate.pii.analyzer") @@ -69,9 +68,7 @@ def __init__(self): PiiAnalyzer._instance = self - def analyze( - self, session_id: str, text: str, context: Optional[PipelineContext] = None - ) -> Tuple[str, List[Dict[str, Any]]]: + def analyze(self, text: str, context: Optional[PipelineContext] = None) -> List: # Prioritize credit card detection first entities = [ "PHONE_NUMBER", @@ -95,65 +92,7 @@ def analyze( language="en", score_threshold=0.3, # Lower threshold to catch more potential matches ) - - # Track found PII - found_pii = [] - - # Only anonymize if PII was found - if analyzer_results: - # Log each found PII instance and anonymize - anonymized_text = text - for result in analyzer_results: - pii_value = text[result.start : result.end] - uuid_placeholder = self.session_store.add_mapping(session_id, pii_value) - pii_info = { - "type": result.entity_type, - "value": pii_value, - "score": result.score, - "start": result.start, - "end": result.end, - "uuid_placeholder": uuid_placeholder, - } - found_pii.append(pii_info) - anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder) - - # Log each PII detection with its UUID mapping - logger.info( - "PII detected and mapped", - pii_type=result.entity_type, - score=f"{result.score:.2f}", - uuid=uuid_placeholder, - # Don't log the actual PII value for security - value_length=len(pii_value), - session_id=session_id, - ) - - # Log summary of all PII found in this analysis - if found_pii and context: - # Create notification string for alert - notify_string = ( - f"**PII Detected** 🔒\n" - f"- Total PII Found: {len(found_pii)}\n" - f"- Types Found: {', '.join(set(p['type'] for p in found_pii))}\n" - ) - context.add_alert( - self._name, - trigger_string=notify_string, - severity_category=AlertSeverity.CRITICAL, - ) - - logger.info( - "PII analysis complete", - total_pii_found=len(found_pii), - pii_types=[p["type"] for p in found_pii], - session_id=session_id, - ) - - # Return the anonymized text, PII details, and session store - return anonymized_text, found_pii - - # If no PII found, return original text, empty list, and session store - return text, [] + return analyzer_results def restore_pii(self, session_id: str, anonymized_text: str) -> str: """ diff --git a/src/codegate/pipeline/pii/manager.py b/src/codegate/pipeline/pii/manager.py deleted file mode 100644 index 0b847bee..00000000 --- a/src/codegate/pipeline/pii/manager.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import structlog - -from codegate.pipeline.base import PipelineContext -from codegate.pipeline.pii.analyzer import PiiAnalyzer -from codegate.session.session_store import SessionStore - -logger = structlog.get_logger("codegate") - - -class PiiManager: - """ - Manages the analysis and restoration of Personally Identifiable Information - (PII) in text. - - Attributes: - analyzer (PiiAnalyzer): The singleton instance of PiiAnalyzer used for - PII detection and restoration. - session_store (SessionStore): The session store for the current PII session. - - Methods: - __init__(): - Initializes the PiiManager with the singleton PiiAnalyzer instance and sets the - session store. - - analyze(session_id: str, text: str) -> Tuple[str, List[Dict[str, Any]]]: - Analyzes the given text for PII, anonymizes it, and logs the detected PII details. - Args: - session_id (str): The session id to store the PII. - text (str): The text to be analyzed for PII. - Returns: - Tuple[str, List[Dict[str, Any]]]: A tuple containing the anonymized text and - a list of found PII details. - - restore_pii(session_id: str, anonymized_text: st ) -> str: - Restores the PII in the given anonymized text using the current session. - Args: - session_id (str): The session id for the PII to be restored. - anonymized_text (str): The text with anonymized PII to be restored. - Returns: - str: The text with restored PII. - """ - - def __init__(self): - """ - Initialize the PiiManager with the singleton PiiAnalyzer instance. - """ - self.analyzer = PiiAnalyzer.get_instance() - # Always use the analyzer's session store - self._session_store = self.analyzer.session_store - - @property - def session_store(self) -> SessionStore: - """Get the current session store.""" - # Always return the analyzer's current session store - return self.analyzer.session_store - - def analyze( - self, session_id: str, text: str, context: Optional[PipelineContext] = None - ) -> Tuple[str, List[Dict[str, Any]]]: - # Call analyzer and get results - anonymized_text, found_pii = self.analyzer.analyze(session_id, text, context=context) - - # Log found PII details (without modifying the found_pii list) - if found_pii: - for pii in found_pii: - logger.info( - "PII detected", - pii_type=pii["type"], - value="*" * len(pii["value"]), # Don't log actual value - score=f"{pii['score']:.2f}", - ) - - # Return the exact same objects we got from the analyzer - return anonymized_text, found_pii - - def restore_pii(self, session_id: str, anonymized_text: str) -> str: - """ - Restore PII in the given anonymized text using the current session. - """ - if not session_id: - return anonymized_text - # Use the analyzer's restore_pii method with the current session store - return self.analyzer.restore_pii(session_id, anonymized_text) diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index 0f7b0e55..bfacc8a8 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional +import re +from typing import Any, Dict, List, Optional, Tuple import uuid import regex as re @@ -7,13 +8,15 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.config import Config +from codegate.db.models import AlertSeverity from codegate.pipeline.base import ( PipelineContext, PipelineResult, PipelineStep, ) from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep -from codegate.pipeline.pii.manager import PiiManager +from codegate.pipeline.pii.analyzer import PiiAnalyzer +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager from codegate.pipeline.systemmsg import add_or_update_system_message logger = structlog.get_logger("codegate") @@ -26,7 +29,7 @@ class CodegatePii(PipelineStep): Methods: __init__: - Initializes the CodegatePii pipeline step and sets up the PiiManager. + Initializes the CodegatePii pipeline step and sets up the SensitiveDataManager. name: Returns the name of the pipeline step. @@ -39,13 +42,14 @@ class CodegatePii(PipelineStep): anonymized text and stores PII details in the context metadata. restore_pii(session_id: str, anonymized_text: str) -> str: - Restores the original PII from the anonymized text using the PiiManager. + Restores the original PII from the anonymized text using the SensitiveDataManager. """ def __init__(self): """Initialize the CodegatePii pipeline step.""" super().__init__() - self.pii_manager = PiiManager() + self.sensitive_data_manager = SensitiveDataManager() + self.analyzer = PiiAnalyzer.get_instance() @property def name(self) -> str: @@ -66,6 +70,68 @@ def _get_redacted_snippet(self, message: str, pii_details: List[Dict[str, Any]]) return message[start:end] + def process_results( + self, session_id: str, text: str, results: List, context: PipelineContext + ) -> Tuple[List, str]: + # Track found PII + found_pii = [] + + # Log each found PII instance and anonymize + anonymized_text = text + for result in results: + pii_value = text[result.start : result.end] + + # add to session store + obj = SensitiveData(pii_value, "pii", result.entity_type) + uuid_placeholder = self.sensitive_data_manager.store(session_id, obj) + anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder) + + # Add to found PII list + pii_info = { + "type": result.entity_type, + "value": pii_value, + "score": result.score, + "start": result.start, + "end": result.end, + "uuid_placeholder": uuid_placeholder, + } + found_pii.append(pii_info) + + # Log each PII detection with its UUID mapping + logger.info( + "PII detected and mapped", + pii_type=result.entity_type, + score=f"{result.score:.2f}", + uuid=uuid_placeholder, + # Don't log the actual PII value for security + value_length=len(pii_value), + session_id=session_id, + ) + + # Log summary of all PII found in this analysis + if found_pii and context: + # Create notification string for alert + notify_string = ( + f"**PII Detected** 🔒\n" + f"- Total PII Found: {len(found_pii)}\n" + f"- Types Found: {', '.join(set(p['type'] for p in found_pii))}\n" + ) + context.add_alert( + self.name, + trigger_string=notify_string, + severity_category=AlertSeverity.CRITICAL, + ) + + logger.info( + "PII analysis complete", + total_pii_found=len(found_pii), + pii_types=[p["type"] for p in found_pii], + session_id=session_id, + ) + + # Return the anonymized text, PII details, and session store + return found_pii, anonymized_text + async def process( self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: @@ -82,20 +148,22 @@ async def process( if "content" in message and message["content"]: # This is where analyze and anonymize the text original_text = str(message["content"]) - anonymized_text, pii_details = self.pii_manager.analyze( - session_id, original_text, context - ) - - if pii_details: - total_pii_found += len(pii_details) - all_pii_details.extend(pii_details) - new_request["messages"][i]["content"] = anonymized_text - - # If this is a user message, grab the redacted snippet! - if message.get("role") == "user": - last_redacted_text = self._get_redacted_snippet( - anonymized_text, pii_details - ) + results = self.analyzer.analyze(original_text, context) + if results: + pii_details, anonymized_text = self.process_results( + session_id, original_text, results, context + ) + + if pii_details: + total_pii_found += len(pii_details) + all_pii_details.extend(pii_details) + new_request["messages"][i]["content"] = anonymized_text + + # If this is a user message, grab the redacted snippet! + if message.get("role") == "user": + last_redacted_text = self._get_redacted_snippet( + anonymized_text, pii_details + ) logger.info(f"Total PII instances redacted: {total_pii_found}") @@ -106,7 +174,7 @@ async def process( context.metadata["session_id"] = session_id if total_pii_found > 0: - context.metadata["pii_manager"] = self.pii_manager + context.metadata["sensitive_data_manager"] = self.sensitive_data_manager system_message = ChatCompletionSystemMessage( content=Config.get_config().prompts.pii_redacted, @@ -119,7 +187,30 @@ async def process( return PipelineResult(request=new_request, context=context) def restore_pii(self, session_id: str, anonymized_text: str) -> str: - return self.pii_manager.restore_pii(session_id, anonymized_text) + """ + Restore the original PII (Personally Identifiable Information) in the given anonymized text. + + This method replaces placeholders in the anonymized text with their corresponding original + PII values using the mappings stored in the provided SessionStore. + + Args: + anonymized_text (str): The text containing placeholders for PII. + session_id (str): The session id containing mappings of placeholders + to original PII. + + Returns: + str: The text with the original PII restored. + """ + session_data = self.sensitive_data_manager.get_by_session_id(session_id) + if not session_data: + logger.warning( + "No active PII session found for given session ID. Unable to restore PII." + ) + return anonymized_text + + for uuid_placeholder, original_pii in session_data.items(): + anonymized_text = anonymized_text.replace(uuid_placeholder, original_pii) + return anonymized_text class PiiUnRedactionStep(OutputPipelineStep): @@ -204,11 +295,13 @@ async def process_chunk( # noqa: C901 if self._is_complete_uuid(uuid_value): # Get the PII manager from context metadata logger.debug(f"Valid UUID found: {uuid_value}") - pii_manager = input_context.metadata.get("pii_manager") if input_context else None - if pii_manager and pii_manager.session_store: + sensitive_data_manager = ( + input_context.metadata.get("sensitive_data_manager") if input_context else None + ) + if sensitive_data_manager and sensitive_data_manager.session_store: # Restore original value from PII manager logger.debug("Attempting to restore PII from UUID marker") - original = pii_manager.session_store.get_mapping(session_id, uuid_marker) + original = sensitive_data_manager.get_original_value(session_id, uuid_marker) logger.debug(f"Restored PII: {original}") result.append(original) else: diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py deleted file mode 100644 index 9476eef9..00000000 --- a/src/codegate/pipeline/secrets/manager.py +++ /dev/null @@ -1,63 +0,0 @@ -import json -from typing import Dict, Optional - -import structlog - -from codegate.session.session_store import SessionStore - -logger = structlog.get_logger("codegate") - - -class SecretsManager: - """Manages encryption, storage and retrieval of secrets""" - - def __init__(self): - self.session_store = SessionStore() - - def store_secret(self, session_id: str, value: str, service: str, secret_type: str) -> str: - """ - Encrypts and stores a secret value. - Returns the encrypted value. - """ - if not session_id: - raise ValueError("Session id must be provided") - - if not value: - raise ValueError("Value must be provided") - if not service: - raise ValueError("Service must be provided") - if not secret_type: - raise ValueError("Secret type must be provided") - - uuid_placeholder = self.session_store.add_mapping( - session_id, - json.dumps({"original": value, "service": service, "secret_type": secret_type}), - ) - logger.debug( - "Stored secret", service=service, type=secret_type, placeholder=uuid_placeholder - ) - return uuid_placeholder - - def get_by_session_id(self, session_id: str) -> Optional[Dict]: - session_data = self.session_store.get_by_session_id(session_id) - if not session_data: - return None - # Convert all string values to dictionary objects using json.loads - return { - key: json.loads(value) if isinstance(value, str) else value - for key, value in session_data.items() - } - - def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]: - """Retrieve original value for an encrypted value""" - secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder) - if secret_entry_json: - secret_entry = json.loads(secret_entry_json) - return secret_entry.get("original") - return None - - def cleanup_session(self, session_id): - self.session_store.cleanup_session(session_id) - - def cleanup(self): - self.session_store.cleanup() diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 68973a6f..c0b83085 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -16,8 +16,8 @@ PipelineStep, ) from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep -from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.signatures import CodegateSignatures, Match +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager from codegate.pipeline.systemmsg import add_or_update_system_message logger = structlog.get_logger("codegate") @@ -171,11 +171,11 @@ def obfuscate(self, text: str, snippet: Optional[CodeSnippet]) -> tuple[str, Lis class SecretsEncryptor(SecretsModifier): def __init__( self, - secrets_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, context: PipelineContext, session_id: str, ): - self._secrets_manager = secrets_manager + self._sensitive_data_manager = sensitive_data_manager self._session_id = session_id self._context = context self._name = "codegate-secrets" @@ -184,13 +184,22 @@ def __init__( def _hide_secret(self, match: Match) -> str: # Encrypt and store the value - encrypted_value = self._secrets_manager.store_secret( - self._session_id, - match.value, - match.service, - match.type, + if not self._session_id: + raise ValueError("Session id must be provided") + + if not match.value: + raise ValueError("Value must be provided") + if not match.service: + raise ValueError("Service must be provided") + if not match.type: + raise ValueError("Secret type must be provided") + + obj = SensitiveData(match.value, match.service, match.type) + uuid_placeholder = self._sensitive_data_manager.store(self._session_id, obj) + logger.debug( + "Stored secret", service=match.service, type=match.type, placeholder=uuid_placeholder ) - return f"REDACTED<${encrypted_value}>" + return f"REDACTED<{uuid_placeholder}>" def _notify_secret( self, match: Match, code_snippet: Optional[CodeSnippet], protected_text: List[str] @@ -252,7 +261,7 @@ def _redact_text( self, text: str, snippet: Optional[CodeSnippet], - secrets_manager: SecretsManager, + sensitive_data_manager: SensitiveDataManager, session_id: str, context: PipelineContext, ) -> tuple[str, List[Match]]: @@ -261,14 +270,14 @@ def _redact_text( Args: text: The text to protect - secrets_manager: .. + sensitive_data_manager: .. session_id: .. context: The pipeline context to be able to log alerts Returns: Tuple containing protected text with encrypted values and the count of redacted secrets """ # Find secrets in the text - text_encryptor = SecretsEncryptor(secrets_manager, context, session_id) + text_encryptor = SecretsEncryptor(sensitive_data_manager, context, session_id) return text_encryptor.obfuscate(text, snippet) async def process( @@ -288,8 +297,10 @@ async def process( if "messages" not in request: return PipelineResult(request=request, context=context) - secrets_manager = context.sensitive.manager - if not secrets_manager or not isinstance(secrets_manager, SecretsManager): + sensitive_data_manager = context.sensitive.manager + if not sensitive_data_manager or not isinstance( + sensitive_data_manager, SensitiveDataManager + ): raise ValueError("Secrets manager not found in context") session_id = context.sensitive.session_id if not session_id: @@ -306,7 +317,7 @@ async def process( for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: redacted_content, secrets_matched = self._redact_message_content( - message["content"], secrets_manager, session_id, context + message["content"], sensitive_data_manager, session_id, context ) new_request["messages"][i]["content"] = redacted_content if i > last_assistant_idx: @@ -314,7 +325,7 @@ async def process( new_request = self._finalize_redaction(context, total_matches, new_request) return PipelineResult(request=new_request, context=context) - def _redact_message_content(self, message_content, secrets_manager, session_id, context): + def _redact_message_content(self, message_content, sensitive_data_manager, session_id, context): # Extract any code snippets extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client) snippets = extractor.extract_snippets(message_content) @@ -323,7 +334,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id, for snippet in snippets: redacted_snippet, secrets_matched = self._redact_text( - snippet, snippet, secrets_manager, session_id, context + snippet, snippet, sensitive_data_manager, session_id, context ) redacted_snippets[snippet.code] = redacted_snippet total_matches.extend(secrets_matched) @@ -337,7 +348,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id, if start_index > last_end: non_snippet_part = message_content[last_end:start_index] redacted_part, secrets_matched = self._redact_text( - non_snippet_part, "", secrets_manager, session_id, context + non_snippet_part, "", sensitive_data_manager, session_id, context ) non_snippet_parts.append(redacted_part) total_matches.extend(secrets_matched) @@ -348,7 +359,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id, if last_end < len(message_content): remaining_text = message_content[last_end:] redacted_remaining, secrets_matched = self._redact_text( - remaining_text, "", secrets_manager, session_id, context + remaining_text, "", sensitive_data_manager, session_id, context ) non_snippet_parts.append(redacted_remaining) total_matches.extend(secrets_matched) diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py new file mode 100644 index 00000000..3bc80ef6 --- /dev/null +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -0,0 +1,56 @@ +import json +from typing import Dict, Optional +import structlog +from codegate.pipeline.sensitive_data.session_store import SessionStore + +logger = structlog.get_logger("codegate") + + +class SensitiveData: + """Represents sensitive data with additional metadata.""" + + def __init__(self, original: str, service: Optional[str], type: Optional[str]): + self.original = original + self.service = service + self.type = type + + def to_json(self) -> str: + """Serializes the object to a JSON string.""" + return json.dumps({key: value for key, value in vars(self).items() if value is not None}) + + @staticmethod + def from_json(data: str) -> "SensitiveData": + """Deserializes from a JSON string to a SensitiveData object.""" + obj = json.loads(data) + return SensitiveData(obj["original"], obj.get("service"), obj.get("type")) + + +class SensitiveDataManager: + """Manages encryption, storage, and retrieval of secrets""" + + def __init__(self): + self.session_store = SessionStore() + + def store(self, session_id: str, value: SensitiveData) -> Optional[str]: + if not session_id or not value.original: + return None + return self.session_store.add_mapping(session_id, value.to_json()) + + def get_by_session_id(self, session_id: str) -> Optional[Dict]: + if not session_id: + return None + data = self.session_store.get_by_session_id(session_id) + return SensitiveData.from_json(data) if data else None + + def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]: + if not session_id: + return None + secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder) + return SensitiveData.from_json(secret_entry_json).original if secret_entry_json else None + + def cleanup_session(self, session_id: str): + if session_id: + self.session_store.cleanup_session(session_id) + + def cleanup(self): + self.session_store.cleanup() diff --git a/src/codegate/session/session_store.py b/src/codegate/pipeline/sensitive_data/session_store.py similarity index 100% rename from src/codegate/session/session_store.py rename to src/codegate/pipeline/sensitive_data/session_store.py diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index b17e98a8..182f2731 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -17,7 +17,7 @@ from codegate.pipeline.base import PipelineContext from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.output import OutputPipelineInstance -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers.copilot.mapping import PIPELINE_ROUTES, VALIDATED_ROUTES, PipelineType from codegate.providers.copilot.pipeline import ( CopilotChatPipeline, @@ -200,7 +200,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.ca = CertificateAuthority.get_instance() self.cert_manager = TLSCertDomainManager(self.ca) self._closing = False - self.pipeline_factory = PipelineFactory(SecretsManager()) + self.pipeline_factory = PipelineFactory(SensitiveDataManager()) self.input_pipeline: Optional[CopilotPipeline] = None self.fim_pipeline: Optional[CopilotPipeline] = None # the context as provided by the pipeline diff --git a/tests/pipeline/pii/test_analyzer.py b/tests/pipeline/pii/test_analyzer.py index 618549df..d626b8cf 100644 --- a/tests/pipeline/pii/test_analyzer.py +++ b/tests/pipeline/pii/test_analyzer.py @@ -4,7 +4,6 @@ from presidio_analyzer import RecognizerResult from codegate.pipeline.pii.analyzer import PiiAnalyzer -from codegate.session.session_store import SessionStore class TestPiiAnalyzer: @@ -68,41 +67,6 @@ def test_singleton_pattern(self): with pytest.raises(RuntimeError, match="Use PiiAnalyzer.get_instance()"): PiiAnalyzer() - def test_analyze_no_pii(self, analyzer, mock_analyzer_engine): - text = "Hello world" - session_id = "session-id" - mock_analyzer_engine.analyze.return_value = [] - - result_text, found_pii = analyzer.analyze(session_id, text) - - assert result_text == text - assert found_pii == [] - - def test_analyze_with_pii(self, analyzer, mock_analyzer_engine): - text = "My email is test@example.com" - session_id = "session-id" - email_pii = RecognizerResult( - entity_type="EMAIL_ADDRESS", - start=12, - end=28, - score=1.0, # EmailRecognizer returns a score of 1.0 - ) - mock_analyzer_engine.analyze.return_value = [email_pii] - - result_text, found_pii = analyzer.analyze(session_id, text) - - assert len(found_pii) == 1 - pii_info = found_pii[0] - assert pii_info["type"] == "EMAIL_ADDRESS" - assert pii_info["value"] == "test@example.com" - assert pii_info["score"] == 1.0 - assert pii_info["start"] == 12 - assert pii_info["end"] == 28 - assert "uuid_placeholder" in pii_info - # Verify the placeholder was used to replace the PII - placeholder = pii_info["uuid_placeholder"] - assert result_text == f"My email is {placeholder}" - def test_restore_pii(self, analyzer): original_text = "test@example.com" session_id = "session-id" diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index 8fa52acf..ac0edb83 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -51,58 +51,6 @@ async def test_process_no_messages(self, pii_step): assert result.request == request assert result.context == context - @pytest.mark.asyncio - async def test_process_with_pii(self, pii_step): - original_text = "My email is test@example.com" - request = ChatCompletionRequest( - model="test-model", messages=[{"role": "user", "content": original_text}] - ) - context = PipelineContext() - - # Mock the PII manager's analyze method - placeholder = "" - pii_details = [ - { - "type": "EMAIL_ADDRESS", - "value": "test@example.com", - "score": 1.0, - "start": 12, - "end": 27, - "uuid_placeholder": placeholder, - } - ] - anonymized_text = f"My email is {placeholder}" - pii_step.pii_manager.analyze = MagicMock(return_value=(anonymized_text, pii_details)) - - result = await pii_step.process(request, context) - - # Verify the user message was anonymized - user_messages = [m for m in result.request["messages"] if m["role"] == "user"] - assert len(user_messages) == 1 - assert user_messages[0]["content"] == anonymized_text - - # Verify metadata was updated - assert result.context.metadata["redacted_pii_count"] == 1 - assert len(result.context.metadata["redacted_pii_details"]) == 1 - # The redacted text should be just the placeholder since that's what _get_redacted_snippet returns # noqa: E501 - assert result.context.metadata["redacted_text"] == placeholder - assert "pii_manager" in result.context.metadata - - # Verify system message was added - system_messages = [m for m in result.request["messages"] if m["role"] == "system"] - assert len(system_messages) == 1 - assert system_messages[0]["content"] == "PII has been redacted" - - def test_restore_pii(self, pii_step): - anonymized_text = "My email is " - original_text = "My email is test@example.com" - session_id = "session-id" - pii_step.pii_manager.restore_pii = MagicMock(return_value=original_text) - - restored = pii_step.restore_pii(session_id, anonymized_text) - - assert restored == original_text - class TestPiiUnRedactionStep: @pytest.fixture @@ -161,11 +109,9 @@ async def test_process_chunk_with_uuid(self, unredaction_step): input_context = PipelineContext(metadata={"session_id": "session-id"}) # Mock PII manager in input context - mock_pii_manager = MagicMock() - mock_session = MagicMock() - mock_session.get_mapping = MagicMock(return_value="test@example.com") - mock_pii_manager.session_store = mock_session - input_context.metadata["pii_manager"] = mock_pii_manager + mock_sensitive_data_manager = MagicMock() + mock_sensitive_data_manager.get_original_value = MagicMock(return_value="test@example.com") + input_context.metadata["sensitive_data_manager"] = mock_sensitive_data_manager result = await unredaction_step.process_chunk(chunk, context, input_context) assert result[0].choices[0].delta.content == "Text with test@example.com" diff --git a/tests/pipeline/pii/test_pii_manager.py b/tests/pipeline/pii/test_pii_manager.py deleted file mode 100644 index aa363240..00000000 --- a/tests/pipeline/pii/test_pii_manager.py +++ /dev/null @@ -1,114 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from codegate.pipeline.pii.analyzer import SessionStore -from codegate.pipeline.pii.manager import PiiManager - - -class TestPiiManager: - @pytest.fixture - def session_store(self): - """Create a session store that will be shared between the mock and manager""" - return SessionStore() - - @pytest.fixture - def mock_analyzer(self, session_store): - """Create a mock analyzer with the shared session store""" - mock_instance = MagicMock() - mock_instance.analyze = MagicMock() - mock_instance.restore_pii = MagicMock() - mock_instance.session_store = session_store - return mock_instance - - @pytest.fixture - def manager(self, mock_analyzer): - """Create a PiiManager instance with the mocked analyzer""" - with patch("codegate.pipeline.pii.manager.PiiAnalyzer") as mock_analyzer_class: - # Set up the mock class to return our mock instance - mock_analyzer_class.get_instance.return_value = mock_analyzer - # Create the manager which will use our mock - return PiiManager() - - def test_init(self, manager, mock_analyzer): - assert manager.session_store is mock_analyzer.session_store - assert manager.analyzer is mock_analyzer - - def test_analyze_no_pii(self, manager, mock_analyzer): - text = "Hello CodeGate" - session_id = "session-id" - session_store = mock_analyzer.session_store - mock_analyzer.analyze.return_value = (text, []) - - anonymized_text, found_pii = manager.analyze(session_id, text) - - assert anonymized_text == text - assert found_pii == [] - assert manager.session_store is session_store - mock_analyzer.analyze.assert_called_once_with(session_id, text, context=None) - - def test_analyze_with_pii(self, manager, mock_analyzer): - text = "My email is test@example.com" - session_id = "session-id" - session_store = mock_analyzer.session_store - placeholder = "" - pii_details = [ - { - "type": "EMAIL_ADDRESS", - "value": "test@example.com", - "score": 0.85, - "start": 12, - "end": 28, # Fixed end position - "uuid_placeholder": placeholder, - } - ] - anonymized_text = f"My email is {placeholder}" - mock_analyzer.analyze.return_value = (anonymized_text, pii_details) - session_store.sessions[session_id] = {placeholder: "test@example.com"} - - result_text, found_pii = manager.analyze(session_id, text) - - assert "My email is <" in result_text - assert ">" in result_text - assert found_pii == pii_details - assert manager.session_store is session_store - - assert manager.session_store.sessions[session_id][placeholder] == "test@example.com" - mock_analyzer.analyze.assert_called_once_with(session_id, text, context=None) - - def test_restore_pii_no_session(self, manager): - text = "Anonymized text" - session_id = "" - restored_text = manager.restore_pii(session_id, text) - - assert restored_text == text - - def test_restore_pii_with_session(self, manager, mock_analyzer): - anonymized_text = "My email is #test-uuid#" - original_text = "My email is test@example.com" - session_id = "session-id" - session_store = mock_analyzer.session_store - session_store.sessions[session_id] = {"#test-uuid#": "test@example.com"} - mock_analyzer.restore_pii.return_value = original_text - - restored_text = manager.restore_pii(session_id, anonymized_text) - - assert restored_text == original_text - mock_analyzer.restore_pii.assert_called_once_with(session_id, anonymized_text) - - def test_restore_pii_multiple_placeholders(self, manager, mock_analyzer): - anonymized_text = "Email: #uuid1#, Phone: #uuid2#" - original_text = "Email: test@example.com, Phone: 123-456-7890" - session_id = "session-id" - session_store = mock_analyzer.session_store - session_store.sessions[session_id] = { - "#uuid1#": "test@example.com", - "#uuid2#": "123-456-7890", - } - - mock_analyzer.restore_pii.return_value = original_text - - restored_text = manager.restore_pii(session_id, anonymized_text) - - assert restored_text == original_text - mock_analyzer.restore_pii.assert_called_once_with(session_id, anonymized_text) diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py deleted file mode 100644 index d7d97787..00000000 --- a/tests/pipeline/secrets/test_manager.py +++ /dev/null @@ -1,140 +0,0 @@ -import pytest - -from codegate.pipeline.secrets.manager import SecretsManager - - -class TestSecretsManager: - def setup_method(self): - """Setup a fresh SecretsManager for each test""" - self.manager = SecretsManager() - self.test_session = "session-id" - self.test_value = "super_secret_value" - self.test_service = "test_service" - self.test_type = "api_key" - - def test_store_secret(self): - """Test basic secret storage and retrieval""" - # Store a secret - encrypted = self.manager.store_secret( - self.test_session, self.test_value, self.test_service, self.test_type - ) - - # Verify the secret was stored - stored = self.manager.get_by_session_id(self.test_session) - assert stored[encrypted]["original"] == self.test_value - - # Verify encrypted value can be retrieved - retrieved = self.manager.get_original_value(self.test_session, encrypted) - assert retrieved == self.test_value - - def test_get_original_value_wrong_session(self): - """Test that secrets can't be accessed with wrong session ID""" - encrypted = self.manager.store_secret( - self.test_session, self.test_value, self.test_service, self.test_type - ) - - # Try to retrieve with wrong session ID - wrong_session = "wrong_session_id" - retrieved = self.manager.get_original_value(wrong_session, encrypted) - assert retrieved is None - - def test_get_original_value_nonexistent(self): - """Test handling of non-existent encrypted values""" - retrieved = self.manager.get_original_value("nonexistent", self.test_session) - assert retrieved is None - - def test_cleanup_session(self): - """Test that session cleanup properly removes secrets""" - # Store multiple secrets in different sessions - session1 = "session1" - session2 = "session2" - - encrypted1 = self.manager.store_secret(session1, "secret1", "service1", "type1") - encrypted2 = self.manager.store_secret(session2, "secret2", "service2", "type2") - - # Clean up session1 - self.manager.cleanup_session(session1) - - # Verify session1 secrets are gone - assert self.manager.get_by_session_id(session1) is None - assert self.manager.get_original_value(session1, encrypted1) is None - - # Verify session2 secrets remain - assert self.manager.get_by_session_id(session2) is not None - assert self.manager.get_original_value(session2, encrypted2) == "secret2" - - def test_cleanup(self): - """Test that cleanup properly wipes all data""" - # Store multiple secrets - self.manager.store_secret("secret1", "service1", "type1", "session1") - self.manager.store_secret("secret2", "service2", "type2", "session2") - - # Perform cleanup - self.manager.cleanup() - - # Verify all data is wiped - assert len(self.manager.session_store.sessions) == 0 - - def test_multiple_secrets_same_session(self): - """Test storing multiple secrets in the same session""" - # Store multiple secrets in same session - encrypted1 = self.manager.store_secret(self.test_session, "secret1", "service1", "type1") - encrypted2 = self.manager.store_secret(self.test_session, "secret2", "service2", "type2") - - # Latest secret should be retrievable in the session - stored = self.manager.get_by_session_id(self.test_session) - assert isinstance(stored, dict) - assert stored[encrypted1]["original"] == "secret1" - assert stored[encrypted2]["original"] == "secret2" - - # Both secrets should be retrievable directly - assert self.manager.get_original_value(self.test_session, encrypted1) == "secret1" - assert self.manager.get_original_value(self.test_session, encrypted2) == "secret2" - - def test_error_handling(self): - """Test error handling in secret operations""" - # Test with None values - with pytest.raises(ValueError): - self.manager.store_secret(self.test_session, None, self.test_service, self.test_type) - - with pytest.raises(ValueError): - self.manager.store_secret(self.test_session, self.test_value, None, self.test_type) - - with pytest.raises(ValueError): - self.manager.store_secret(self.test_session, self.test_value, self.test_service, None) - - with pytest.raises(ValueError): - self.manager.store_secret(None, self.test_value, self.test_service, self.test_type) - - def test_secure_cleanup(self): - """Test that cleanup securely wipes sensitive data""" - # Store a secret - self.manager.store_secret( - self.test_session, self.test_value, self.test_service, self.test_type - ) - - # Get reference to stored data before cleanup - stored = self.manager.get_by_session_id(self.test_session) - assert len(stored) == 1 - - # Perform cleanup - self.manager.cleanup() - - assert len(self.manager.session_store.sessions) == 0 - - def test_session_isolation(self): - """Test that sessions are properly isolated""" - session1 = "session1" - session2 = "session2" - - # Store secrets in different sessions - encrypted1 = self.manager.store_secret(session1, "secret1", "service1", "type1") - encrypted2 = self.manager.store_secret(session2, "secret2", "service2", "type2") - - # Verify cross-session access is not possible - assert self.manager.get_original_value(session2, encrypted1) is None - assert self.manager.get_original_value(session1, encrypted2) is None - - # Verify correct session access works - assert self.manager.get_original_value(session1, encrypted1) == "secret1" - assert self.manager.get_original_value(session2, encrypted2) == "secret2" diff --git a/tests/pipeline/secrets/test_secrets.py b/tests/pipeline/secrets/test_secrets.py index e5995a93..0c471d4e 100644 --- a/tests/pipeline/secrets/test_secrets.py +++ b/tests/pipeline/secrets/test_secrets.py @@ -7,13 +7,13 @@ from codegate.pipeline.base import PipelineContext, PipelineSensitiveData from codegate.pipeline.output import OutputPipelineContext -from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.secrets import ( SecretsEncryptor, SecretsObfuscator, SecretUnredactionStep, ) from codegate.pipeline.secrets.signatures import CodegateSignatures, Match +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager class TestSecretsModifier: @@ -69,9 +69,11 @@ class TestSecretsEncryptor: def setup(self, temp_yaml_file): CodegateSignatures.initialize(temp_yaml_file) self.context = PipelineContext() - self.secrets_manager = SecretsManager() + self.sensitive_data_manager = SensitiveDataManager() self.session_id = "test_session" - self.encryptor = SecretsEncryptor(self.secrets_manager, self.context, self.session_id) + self.encryptor = SecretsEncryptor( + self.sensitive_data_manager, self.context, self.session_id + ) def test_hide_secret(self): # Create a test match @@ -87,12 +89,12 @@ def test_hide_secret(self): # Test secret hiding hidden = self.encryptor._hide_secret(match) - assert hidden.startswith("REDACTED<$") + assert hidden.startswith("REDACTED<") assert hidden.endswith(">") # Verify the secret was stored - encrypted_value = hidden[len("REDACTED<$") : -1] - original = self.secrets_manager.get_original_value(self.session_id, encrypted_value) + encrypted_value = hidden[len("REDACTED<") : -1] + original = self.sensitive_data_manager.get_original_value(self.session_id, encrypted_value) assert original == "AKIAIOSFODNN7EXAMPLE" def test_obfuscate(self): @@ -101,7 +103,7 @@ def test_obfuscate(self): protected, matched_secrets = self.encryptor.obfuscate(text, None) assert len(matched_secrets) == 1 - assert "REDACTED<$" in protected + assert "REDACTED<" in protected assert "AKIAIOSFODNN7EXAMPLE" not in protected assert "Other text" in protected @@ -171,25 +173,24 @@ def setup_method(self): """Setup fresh instances for each test""" self.step = SecretUnredactionStep() self.context = OutputPipelineContext() - self.secrets_manager = SecretsManager() + self.sensitive_data_manager = SensitiveDataManager() self.session_id = "test_session" # Setup input context with secrets manager self.input_context = PipelineContext() self.input_context.sensitive = PipelineSensitiveData( - manager=self.secrets_manager, session_id=self.session_id + manager=self.sensitive_data_manager, session_id=self.session_id ) @pytest.mark.asyncio async def test_complete_marker_processing(self): """Test processing of a complete REDACTED marker""" # Store a secret - encrypted = self.secrets_manager.store_secret( - self.session_id, "secret_value", "test_service", "api_key" - ) + obj = SensitiveData("secret_value", "test_service", "api_key") + encrypted = self.sensitive_data_manager.store(self.session_id, obj) # Add content with REDACTED marker to buffer - self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + self.context.buffer.append(f"Here is the REDACTED<{encrypted}> in text") # Process a chunk result = await self.step.process_chunk( @@ -204,7 +205,7 @@ async def test_complete_marker_processing(self): async def test_partial_marker_buffering(self): """Test handling of partial REDACTED markers""" # Add partial marker to buffer - self.context.buffer.append("Here is REDACTED<$") + self.context.buffer.append("Here is REDACTED<") # Process a chunk result = await self.step.process_chunk( @@ -218,7 +219,7 @@ async def test_partial_marker_buffering(self): async def test_invalid_encrypted_value(self): """Test handling of invalid encrypted values""" # Add content with invalid encrypted value - self.context.buffer.append("Here is REDACTED<$invalid_value> in text") + self.context.buffer.append("Here is REDACTED in text") # Process chunk result = await self.step.process_chunk( @@ -227,7 +228,7 @@ async def test_invalid_encrypted_value(self): # Should keep the REDACTED marker for invalid values assert len(result) == 1 - assert result[0].choices[0].delta.content == "Here is REDACTED<$invalid_value> in text" + assert result[0].choices[0].delta.content == "Here is REDACTED in text" @pytest.mark.asyncio async def test_missing_context(self): @@ -271,12 +272,11 @@ async def test_no_markers(self): async def test_wrong_session(self): """Test unredaction with wrong session ID""" # Store secret with one session - encrypted = self.secrets_manager.store_secret( - "secret_value", "test_service", "api_key", "different_session" - ) + obj = SensitiveData("test_service", "api_key", "different_session") + encrypted = self.sensitive_data_manager.store("different_session", obj) # Try to unredact with different session - self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + self.context.buffer.append(f"Here is the REDACTED<{encrypted}> in text") result = await self.step.process_chunk( create_model_response("text"), self.context, self.input_context @@ -284,4 +284,4 @@ async def test_wrong_session(self): # Should keep REDACTED marker when session doesn't match assert len(result) == 1 - assert result[0].choices[0].delta.content == f"Here is the REDACTED<${encrypted}> in text" + assert result[0].choices[0].delta.content == f"Here is the REDACTED<{encrypted}> in text" diff --git a/tests/test_server.py b/tests/test_server.py index 1e06c096..aa549810 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,19 +14,13 @@ from codegate import __version__ from codegate.pipeline.factory import PipelineFactory -from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager from codegate.providers.registry import ProviderRegistry from codegate.server import init_app from src.codegate.cli import UvicornServer, cli from src.codegate.codegate_logging import LogFormat, LogLevel -@pytest.fixture -def mock_secrets_manager(): - """Create a mock secrets manager.""" - return MagicMock(spec=SecretsManager) - - @pytest.fixture def mock_provider_registry(): """Create a mock provider registry.""" @@ -96,9 +90,9 @@ def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> assert response_data["is_latest"] is False -@patch("codegate.pipeline.secrets.manager.SecretsManager") +@patch("codegate.pipeline.sensitive_data.manager.SensitiveDataManager") @patch("codegate.server.get_provider_registry") -def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None: +def test_provider_registration(mock_registry, mock_pipeline_factory) -> None: """Test that all providers are registered correctly.""" init_app(mock_pipeline_factory) From e5f33ddeaba9cad080eeb7a8e33b52daa7dec78b Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Fri, 21 Feb 2025 11:41:15 +0100 Subject: [PATCH 4/8] add missing tests --- .../pipeline/sensitive_data/manager.py | 3 + tests/pipeline/sensitive_data/test_manager.py | 46 +++++++ .../sensitive_data/test_session_store.py | 114 ++++++++++++++++++ 3 files changed, 163 insertions(+) create mode 100644 tests/pipeline/sensitive_data/test_manager.py create mode 100644 tests/pipeline/sensitive_data/test_session_store.py diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py index 3bc80ef6..08072bdc 100644 --- a/src/codegate/pipeline/sensitive_data/manager.py +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -32,8 +32,11 @@ def __init__(self): self.session_store = SessionStore() def store(self, session_id: str, value: SensitiveData) -> Optional[str]: + print("in store") if not session_id or not value.original: return None + print("i call add mapping") + print(self.session_store) return self.session_store.add_mapping(session_id, value.to_json()) def get_by_session_id(self, session_id: str) -> Optional[Dict]: diff --git a/tests/pipeline/sensitive_data/test_manager.py b/tests/pipeline/sensitive_data/test_manager.py new file mode 100644 index 00000000..44b2b857 --- /dev/null +++ b/tests/pipeline/sensitive_data/test_manager.py @@ -0,0 +1,46 @@ +import json +from unittest.mock import MagicMock, patch +import pytest +from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager +from codegate.pipeline.sensitive_data.session_store import SessionStore + + +class TestSensitiveDataManager: + @pytest.fixture + def mock_session_store(self): + """Mock the SessionStore instance used within SensitiveDataManager.""" + return MagicMock(spec=SessionStore) + + @pytest.fixture + def manager(self, mock_session_store): + """Patch SensitiveDataManager to use the mocked SessionStore.""" + with patch.object(SensitiveDataManager, "__init__", lambda self: None): + manager = SensitiveDataManager() + manager.session_store = mock_session_store # Manually inject the mock + return manager + + def test_store_success(self, manager, mock_session_store): + """Test storing a SensitiveData object successfully.""" + session_id = "session-123" + sensitive_data = SensitiveData("secret_value", "AWS", "API_KEY") + + # Mock session store behavior + mock_session_store.add_mapping.return_value = "uuid-123" + + result = manager.store(session_id, sensitive_data) + + # Verify correct function calls + mock_session_store.add_mapping.assert_called_once_with(session_id, sensitive_data.to_json()) + assert result == "uuid-123" + + def test_store_invalid_session_id(self, manager): + """Test storing data with an invalid session ID (should return None).""" + sensitive_data = SensitiveData("secret_value", "AWS", "API_KEY") + result = manager.store("", sensitive_data) # Empty session ID + assert result is None + + def test_store_missing_original_value(self, manager): + """Test storing data without an original value (should return None).""" + sensitive_data = SensitiveData(original="", service="AWS", type="API_KEY") # Empty original + result = manager.store("session-123", sensitive_data) + assert result is None diff --git a/tests/pipeline/sensitive_data/test_session_store.py b/tests/pipeline/sensitive_data/test_session_store.py new file mode 100644 index 00000000..b9ab64fe --- /dev/null +++ b/tests/pipeline/sensitive_data/test_session_store.py @@ -0,0 +1,114 @@ +import uuid +import pytest +from codegate.pipeline.sensitive_data.session_store import SessionStore + + +class TestSessionStore: + @pytest.fixture + def session_store(self): + """Fixture to create a fresh SessionStore instance before each test.""" + return SessionStore() + + def test_add_mapping_creates_uuid(self, session_store): + """Test that add_mapping correctly stores data and returns a UUID.""" + session_id = "session-123" + data = "test-data" + + uuid_placeholder = session_store.add_mapping(session_id, data) + + # Ensure the returned placeholder follows the expected format + assert uuid_placeholder.startswith("#") and uuid_placeholder.endswith("#") + assert len(uuid_placeholder) > 2 # Should have a UUID inside + + # Verify data is correctly stored + stored_data = session_store.get_mapping(session_id, uuid_placeholder) + assert stored_data == data + + def test_add_mapping_creates_unique_uuids(self, session_store): + """Ensure multiple calls to add_mapping generate unique UUIDs.""" + session_id = "session-123" + data1 = "data1" + data2 = "data2" + + uuid_placeholder1 = session_store.add_mapping(session_id, data1) + uuid_placeholder2 = session_store.add_mapping(session_id, data2) + + assert uuid_placeholder1 != uuid_placeholder2 # UUIDs must be unique + + # Ensure data is correctly stored + assert session_store.get_mapping(session_id, uuid_placeholder1) == data1 + assert session_store.get_mapping(session_id, uuid_placeholder2) == data2 + + def test_get_by_session_id(self, session_store): + """Test retrieving all stored mappings for a session ID.""" + session_id = "session-123" + data1 = "data1" + data2 = "data2" + + uuid1 = session_store.add_mapping(session_id, data1) + uuid2 = session_store.add_mapping(session_id, data2) + + stored_session_data = session_store.get_by_session_id(session_id) + + assert uuid1 in stored_session_data + assert uuid2 in stored_session_data + assert stored_session_data[uuid1] == data1 + assert stored_session_data[uuid2] == data2 + + def test_get_by_session_id_not_found(self, session_store): + """Test get_by_session_id when session does not exist (should return None).""" + session_id = "non-existent-session" + assert session_store.get_by_session_id(session_id) is None + + def test_get_mapping_success(self, session_store): + """Test retrieving a specific mapping.""" + session_id = "session-123" + data = "test-data" + + uuid_placeholder = session_store.add_mapping(session_id, data) + + assert session_store.get_mapping(session_id, uuid_placeholder) == data + + def test_get_mapping_not_found(self, session_store): + """Test retrieving a mapping that does not exist (should return None).""" + session_id = "session-123" + uuid_placeholder = "#non-existent-uuid#" + + assert session_store.get_mapping(session_id, uuid_placeholder) is None + + def test_cleanup_session(self, session_store): + """Test that cleanup_session removes all data for a session ID.""" + session_id = "session-123" + session_store.add_mapping(session_id, "test-data") + + # Ensure session exists before cleanup + assert session_store.get_by_session_id(session_id) is not None + + session_store.cleanup_session(session_id) + + # Ensure session is removed after cleanup + assert session_store.get_by_session_id(session_id) is None + + def test_cleanup_session_non_existent(self, session_store): + """Test cleanup_session on a non-existent session (should not raise errors).""" + session_id = "non-existent-session" + session_store.cleanup_session(session_id) # Should not fail + assert session_store.get_by_session_id(session_id) is None + + def test_cleanup(self, session_store): + """Test global cleanup removes all stored sessions.""" + session_id1 = "session-1" + session_id2 = "session-2" + + session_store.add_mapping(session_id1, "data1") + session_store.add_mapping(session_id2, "data2") + + # Ensure sessions exist before cleanup + assert session_store.get_by_session_id(session_id1) is not None + assert session_store.get_by_session_id(session_id2) is not None + + session_store.cleanup() + + # Ensure all sessions are removed after cleanup + assert session_store.get_by_session_id(session_id1) is None + assert session_store.get_by_session_id(session_id2) is None From 37922d45c28297f8deeb484002be87bccb908b72 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Fri, 21 Feb 2025 11:46:10 +0100 Subject: [PATCH 5/8] changes from rebase --- src/codegate/pipeline/pii/pii.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index bfacc8a8..d8e5f988 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -1,4 +1,3 @@ -import re from typing import Any, Dict, List, Optional, Tuple import uuid From 58f5e93264eec204b533f98456c2bec7fa004035 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Mon, 24 Feb 2025 09:48:58 +0100 Subject: [PATCH 6/8] fixes from review --- src/codegate/pipeline/base.py | 3 --- src/codegate/pipeline/pii/pii.py | 6 ++--- .../pipeline/sensitive_data/manager.py | 26 +++++++------------ 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 29c60b62..ddcd5a61 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -21,10 +21,7 @@ class PipelineSensitiveData: manager: SensitiveDataManager session_id: str - api_key: Optional[str] = None model: Optional[str] = None - provider: Optional[str] = None - api_base: Optional[str] = None def secure_cleanup(self): """Securely cleanup sensitive data for this session""" diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index d8e5f988..ee97604b 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -141,7 +141,7 @@ async def process( total_pii_found = 0 all_pii_details: List[Dict[str, Any]] = [] last_redacted_text = "" - session_id = context.session_id if hasattr(context, "session_id") else str(uuid.uuid4()) + session_id = context.session_id if hasattr(context, "session_id") else None for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: @@ -271,13 +271,13 @@ async def process_chunk( # noqa: C901 current_pos = 0 result = [] while current_pos < len(content): - start_idx = content.find("#", current_pos) + start_idx = content.find(self.marker_start, current_pos) if start_idx == -1: # No more markers!, add remaining content result.append(content[current_pos:]) break - end_idx = content.find("#", start_idx + 1) + end_idx = content.find(self.marker_end, start_idx + 1) if end_idx == -1: # Incomplete marker, buffer the rest context.prefix_buffer = content[current_pos:] diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py index 08072bdc..d2fc20ca 100644 --- a/src/codegate/pipeline/sensitive_data/manager.py +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -1,12 +1,13 @@ import json from typing import Dict, Optional +import pydantic import structlog from codegate.pipeline.sensitive_data.session_store import SessionStore logger = structlog.get_logger("codegate") -class SensitiveData: +class SensitiveData(pydantic.BaseModel): """Represents sensitive data with additional metadata.""" def __init__(self, original: str, service: Optional[str], type: Optional[str]): @@ -14,16 +15,6 @@ def __init__(self, original: str, service: Optional[str], type: Optional[str]): self.service = service self.type = type - def to_json(self) -> str: - """Serializes the object to a JSON string.""" - return json.dumps({key: value for key, value in vars(self).items() if value is not None}) - - @staticmethod - def from_json(data: str) -> "SensitiveData": - """Deserializes from a JSON string to a SensitiveData object.""" - obj = json.loads(data) - return SensitiveData(obj["original"], obj.get("service"), obj.get("type")) - class SensitiveDataManager: """Manages encryption, storage, and retrieval of secrets""" @@ -32,24 +23,25 @@ def __init__(self): self.session_store = SessionStore() def store(self, session_id: str, value: SensitiveData) -> Optional[str]: - print("in store") if not session_id or not value.original: return None - print("i call add mapping") - print(self.session_store) - return self.session_store.add_mapping(session_id, value.to_json()) + return self.session_store.add_mapping(session_id, value.model_dump_json()) def get_by_session_id(self, session_id: str) -> Optional[Dict]: if not session_id: return None data = self.session_store.get_by_session_id(session_id) - return SensitiveData.from_json(data) if data else None + return SensitiveData.model_validate_json(data) if data else None def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]: if not session_id: return None secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder) - return SensitiveData.from_json(secret_entry_json).original if secret_entry_json else None + return ( + SensitiveData.model_validate_json(secret_entry_json).original + if secret_entry_json + else None + ) def cleanup_session(self, session_id: str): if session_id: From 606889b7540181476e5b89b97dff7989929907ea Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Mon, 24 Feb 2025 10:14:15 +0100 Subject: [PATCH 7/8] fixes in tests --- src/codegate/pipeline/pii/pii.py | 2 +- src/codegate/pipeline/secrets/secrets.py | 2 +- src/codegate/pipeline/sensitive_data/manager.py | 7 +++---- tests/pipeline/secrets/test_secrets.py | 4 ++-- tests/pipeline/sensitive_data/test_manager.py | 8 +++++--- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index ee97604b..2b28ca6f 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -81,7 +81,7 @@ def process_results( pii_value = text[result.start : result.end] # add to session store - obj = SensitiveData(pii_value, "pii", result.entity_type) + obj = SensitiveData(original=pii_value, service="pii", type=result.entity_type) uuid_placeholder = self.sensitive_data_manager.store(session_id, obj) anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index c0b83085..527c817f 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -194,7 +194,7 @@ def _hide_secret(self, match: Match) -> str: if not match.type: raise ValueError("Secret type must be provided") - obj = SensitiveData(match.value, match.service, match.type) + obj = SensitiveData(original=match.value, service=match.service, type=match.type) uuid_placeholder = self._sensitive_data_manager.store(self._session_id, obj) logger.debug( "Stored secret", service=match.service, type=match.type, placeholder=uuid_placeholder diff --git a/src/codegate/pipeline/sensitive_data/manager.py b/src/codegate/pipeline/sensitive_data/manager.py index d2fc20ca..89506d15 100644 --- a/src/codegate/pipeline/sensitive_data/manager.py +++ b/src/codegate/pipeline/sensitive_data/manager.py @@ -10,10 +10,9 @@ class SensitiveData(pydantic.BaseModel): """Represents sensitive data with additional metadata.""" - def __init__(self, original: str, service: Optional[str], type: Optional[str]): - self.original = original - self.service = service - self.type = type + original: str + service: Optional[str] = None + type: Optional[str] = None class SensitiveDataManager: diff --git a/tests/pipeline/secrets/test_secrets.py b/tests/pipeline/secrets/test_secrets.py index 0c471d4e..3f272b5b 100644 --- a/tests/pipeline/secrets/test_secrets.py +++ b/tests/pipeline/secrets/test_secrets.py @@ -186,7 +186,7 @@ def setup_method(self): async def test_complete_marker_processing(self): """Test processing of a complete REDACTED marker""" # Store a secret - obj = SensitiveData("secret_value", "test_service", "api_key") + obj = SensitiveData(original="secret_value", service="test_service", type="api_key") encrypted = self.sensitive_data_manager.store(self.session_id, obj) # Add content with REDACTED marker to buffer @@ -272,7 +272,7 @@ async def test_no_markers(self): async def test_wrong_session(self): """Test unredaction with wrong session ID""" # Store secret with one session - obj = SensitiveData("test_service", "api_key", "different_session") + obj = SensitiveData(original="test_service", service="api_key", type="different_session") encrypted = self.sensitive_data_manager.store("different_session", obj) # Try to unredact with different session diff --git a/tests/pipeline/sensitive_data/test_manager.py b/tests/pipeline/sensitive_data/test_manager.py index 44b2b857..6115ad14 100644 --- a/tests/pipeline/sensitive_data/test_manager.py +++ b/tests/pipeline/sensitive_data/test_manager.py @@ -22,7 +22,7 @@ def manager(self, mock_session_store): def test_store_success(self, manager, mock_session_store): """Test storing a SensitiveData object successfully.""" session_id = "session-123" - sensitive_data = SensitiveData("secret_value", "AWS", "API_KEY") + sensitive_data = SensitiveData(original="secret_value", service="AWS", type="API_KEY") # Mock session store behavior mock_session_store.add_mapping.return_value = "uuid-123" @@ -30,12 +30,14 @@ def test_store_success(self, manager, mock_session_store): result = manager.store(session_id, sensitive_data) # Verify correct function calls - mock_session_store.add_mapping.assert_called_once_with(session_id, sensitive_data.to_json()) + mock_session_store.add_mapping.assert_called_once_with( + session_id, sensitive_data.model_dump_json() + ) assert result == "uuid-123" def test_store_invalid_session_id(self, manager): """Test storing data with an invalid session ID (should return None).""" - sensitive_data = SensitiveData("secret_value", "AWS", "API_KEY") + sensitive_data = SensitiveData(original="secret_value", service="AWS", type="API_KEY") result = manager.store("", sensitive_data) # Empty session ID assert result is None From 0c63a2598c5f83334eed71924678e8cf8b6e7049 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Tue, 4 Mar 2025 13:54:32 +0100 Subject: [PATCH 8/8] fix tests --- src/codegate/pipeline/factory.py | 4 ++-- src/codegate/pipeline/pii/pii.py | 8 ++++---- tests/pipeline/pii/test_pi.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index 3a2e3479..813459d5 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -32,7 +32,7 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr # and without obfuscating the secrets, we'd leak the secrets during those # later steps CodegateSecrets(), - CodegatePii(), + CodegatePii(self.sensitive_data_manager), CodegateCli(), CodegateContextRetriever(), SystemPrompt( @@ -49,7 +49,7 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr def create_fim_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor: fim_steps: List[PipelineStep] = [ CodegateSecrets(), - CodegatePii(), + CodegatePii(self.sensitive_data_manager), ] return SequentialPipelineProcessor( fim_steps, diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index 2b28ca6f..fde89428 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -44,10 +44,10 @@ class CodegatePii(PipelineStep): Restores the original PII from the anonymized text using the SensitiveDataManager. """ - def __init__(self): + def __init__(self, sensitive_data_manager: SensitiveDataManager): """Initialize the CodegatePii pipeline step.""" super().__init__() - self.sensitive_data_manager = SensitiveDataManager() + self.sensitive_data_manager = sensitive_data_manager self.analyzer = PiiAnalyzer.get_instance() @property @@ -141,7 +141,7 @@ async def process( total_pii_found = 0 all_pii_details: List[Dict[str, Any]] = [] last_redacted_text = "" - session_id = context.session_id if hasattr(context, "session_id") else None + session_id = context.sensitive.session_id for i, message in enumerate(new_request["messages"]): if "content" in message and message["content"]: @@ -257,7 +257,7 @@ async def process_chunk( # noqa: C901 return [chunk] content = chunk.choices[0].delta.content - session_id = input_context.metadata.get("session_id", "") + session_id = input_context.sensitive.session_id if not session_id: logger.error("Could not get any session id, cannot process pii") return [chunk] diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index ac0edb83..06d2881f 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -4,9 +4,10 @@ from litellm import ChatCompletionRequest, ModelResponse from litellm.types.utils import Delta, StreamingChoices -from codegate.pipeline.base import PipelineContext +from codegate.pipeline.base import PipelineContext, PipelineSensitiveData from codegate.pipeline.output import OutputPipelineContext from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep +from codegate.pipeline.sensitive_data.manager import SensitiveDataManager class TestCodegatePii: @@ -19,8 +20,9 @@ def mock_config(self): yield mock_config @pytest.fixture - def pii_step(self, mock_config): - return CodegatePii() + def pii_step(self): + mock_sensitive_data_manager = MagicMock() + return CodegatePii(mock_sensitive_data_manager) def test_name(self, pii_step): assert pii_step.name == "codegate-pii" @@ -106,7 +108,9 @@ async def test_process_chunk_with_uuid(self, unredaction_step): object="chat.completion.chunk", ) context = OutputPipelineContext() - input_context = PipelineContext(metadata={"session_id": "session-id"}) + manager = SensitiveDataManager() + sensitive = PipelineSensitiveData(manager=manager, session_id="session-id") + input_context = PipelineContext(sensitive=sensitive) # Mock PII manager in input context mock_sensitive_data_manager = MagicMock()