From 51cf19653a2eef0686fad5c8238be8e37f3c17e7 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Date: Tue, 7 Jan 2025 10:30:42 +0100 Subject: [PATCH] feat: improve cache system to collect the last output Due to the cache system, we were collecting only the initial output of fim, that was including incomplete output. Add an update method to the cache, so we can collect all the output that comes from fim, associated to the same request Closes: #472 --- src/codegate/dashboard/dashboard.py | 4 +- src/codegate/db/connection.py | 99 ++++++++++++++++++++--------- src/codegate/db/fim_cache.py | 69 ++++++++++++++------ tests/db/test_fim_cache.py | 11 ++-- 4 files changed, 129 insertions(+), 54 deletions(-) diff --git a/src/codegate/dashboard/dashboard.py b/src/codegate/dashboard/dashboard.py index 4ed39f34..19352b51 100644 --- a/src/codegate/dashboard/dashboard.py +++ b/src/codegate/dashboard/dashboard.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncGenerator, List +from typing import AsyncGenerator, List, Optional import structlog from fastapi import APIRouter, Depends @@ -36,7 +36,7 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat @dashboard_router.get("/dashboard/alerts") -def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[AlertConversation]: +def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]: """ Get all the messages from the database and return them as a list of conversations. """ diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index c8fb60d0..af7c3b98 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -5,7 +5,7 @@ import structlog from pydantic import BaseModel -from sqlalchemy import text +from sqlalchemy import TextClause, text from sqlalchemy.ext.asyncio import create_async_engine from codegate.db.fim_cache import FimCache @@ -30,8 +30,8 @@ def __init__(self, sqlite_path: Optional[str] = None): current_dir = Path(__file__).parent sqlite_path = ( current_dir.parent.parent.parent / "codegate_volume" / "db" / "codegate.db" - ) - self._db_path = Path(sqlite_path).absolute() + ) # type: ignore + self._db_path = Path(sqlite_path).absolute() # type: ignore self._db_path.parent.mkdir(parents=True, exist_ok=True) logger.debug(f"Initializing DB from path: {self._db_path}") engine_dict = { @@ -82,15 +82,15 @@ async def init_db(self): finally: await self._async_db_engine.dispose() - async def _insert_pydantic_model( - self, model: BaseModel, sql_insert: text + async def _execute_update_pydantic_model( + self, model: BaseModel, sql_command: TextClause # ) -> Optional[BaseModel]: # There are create method in queries.py automatically generated by sqlc # However, the methods are buggy for Pydancti and don't work as expected. # Manually writing the SQL query to insert Pydantic models. async with self._async_db_engine.begin() as conn: try: - result = await conn.execute(sql_insert, model.model_dump()) + result = await conn.execute(sql_command, model.model_dump()) row = result.first() if row is None: return None @@ -99,7 +99,7 @@ async def _insert_pydantic_model( model_class = model.__class__ return model_class(**row._asdict()) except Exception as e: - logger.error(f"Failed to insert model: {model}.", error=str(e)) + logger.error(f"Failed to update model: {model}.", error=str(e)) return None async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]: @@ -112,18 +112,39 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option RETURNING * """ ) - recorded_request = await self._insert_pydantic_model(prompt_params, sql) + recorded_request = await self._execute_update_pydantic_model(prompt_params, sql) # Uncomment to debug the recorded request # logger.debug(f"Recorded request: {recorded_request}") - return recorded_request + return recorded_request # type: ignore - async def record_outputs(self, outputs: List[Output]) -> Optional[Output]: + async def update_request(self, initial_id: str, + prompt_params: Optional[Prompt] = None) -> Optional[Prompt]: + if prompt_params is None: + return None + prompt_params.id = initial_id # overwrite the initial id of the request + sql = text( + """ + UPDATE prompts + SET timestamp = :timestamp, provider = :provider, request = :request, type = :type + WHERE id = :id + RETURNING * + """ + ) + updated_request = await self._execute_update_pydantic_model(prompt_params, sql) + # Uncomment to debug the recorded request + # logger.debug(f"Recorded request: {recorded_request}") + return updated_request # type: ignore + + async def record_outputs(self, outputs: List[Output], + initial_id: Optional[str]) -> Optional[Output]: if not outputs: return first_output = outputs[0] # Create a single entry on DB but encode all of the chunks in the stream as a list # of JSON objects in the field `output` + if initial_id: + first_output.prompt_id = initial_id output_db = Output( id=first_output.id, prompt_id=first_output.prompt_id, @@ -143,14 +164,14 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]: RETURNING * """ ) - recorded_output = await self._insert_pydantic_model(output_db, sql) + recorded_output = await self._execute_update_pydantic_model(output_db, sql) # Uncomment to debug # logger.debug(f"Recorded output: {recorded_output}") - return recorded_output + return recorded_output # type: ignore - async def record_alerts(self, alerts: List[Alert]) -> List[Alert]: + async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> List[Alert]: if not alerts: - return + return [] sql = text( """ INSERT INTO alerts ( @@ -167,7 +188,9 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]: async with asyncio.TaskGroup() as tg: for alert in alerts: try: - result = tg.create_task(self._insert_pydantic_model(alert, sql)) + if initial_id: + alert.prompt_id = initial_id + result = tg.create_task(self._execute_update_pydantic_model(alert, sql)) alerts_tasks.append(result) except Exception as e: logger.error(f"Failed to record alert: {alert}.", error=str(e)) @@ -182,33 +205,49 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]: # logger.debug(f"Recorded alerts: {recorded_alerts}") return recorded_alerts - def _should_record_context(self, context: Optional[PipelineContext]) -> bool: - """Check if the context should be recorded in DB""" + def _should_record_context(self, context: Optional[PipelineContext]) -> tuple: + """Check if the context should be recorded in DB and determine the action.""" if context is None or context.metadata.get("stored_in_db", False): - return False + return False, None, None if not context.input_request: logger.warning("No input request found. Skipping recording context.") - return False + return False, None, None # If it's not a FIM prompt, we don't need to check anything else. if context.input_request.type != "fim": - return True + return True, 'add', '' # Default to add if not FIM, since no cache check is required - return fim_cache.could_store_fim_request(context) + return fim_cache.could_store_fim_request(context) # type: ignore async def record_context(self, context: Optional[PipelineContext]) -> None: try: - if not self._should_record_context(context): + if not context: + logger.info("No context provided, skipping") return - await self.record_request(context.input_request) - await self.record_outputs(context.output_responses) - await self.record_alerts(context.alerts_raised) - context.metadata["stored_in_db"] = True - logger.info( - f"Recorded context in DB. Output chunks: {len(context.output_responses)}. " - f"Alerts: {len(context.alerts_raised)}." - ) + should_record, action, initial_id = self._should_record_context(context) + if not should_record: + logger.info("Skipping record of context, not needed") + return + if action == 'add': + await self.record_request(context.input_request) + await self.record_outputs(context.output_responses, None) + await self.record_alerts(context.alerts_raised, None) + context.metadata["stored_in_db"] = True + logger.info( + f"Recorded context in DB. Output chunks: {len(context.output_responses)}. " + f"Alerts: {len(context.alerts_raised)}." + ) + else: + # update them + await self.update_request(initial_id, context.input_request) + await self.record_outputs(context.output_responses, initial_id) + await self.record_alerts(context.alerts_raised, initial_id) + context.metadata["stored_in_db"] = True + logger.info( + f"Recorded context in DB. Output chunks: {len(context.output_responses)}. " + f"Alerts: {len(context.alerts_raised)}." + ) except Exception as e: logger.error(f"Failed to record context: {context}.", error=str(e)) diff --git a/src/codegate/db/fim_cache.py b/src/codegate/db/fim_cache.py index 2a2d8761..e5a488b6 100644 --- a/src/codegate/db/fim_cache.py +++ b/src/codegate/db/fim_cache.py @@ -18,6 +18,7 @@ class CachedFim(BaseModel): timestamp: datetime.datetime critical_alerts: List[Alert] + initial_id: str class FimCache: @@ -86,16 +87,42 @@ def _calculate_hash_key(self, message: str, provider: str) -> str: def _add_cache_entry(self, hash_key: str, context: PipelineContext): """Add a new cache entry""" + if not context.input_request: + logger.warning("No input request found. Skipping creating a mapping entry") + return critical_alerts = [ alert for alert in context.alerts_raised if alert.trigger_category == AlertSeverity.CRITICAL.value ] new_cache = CachedFim( - timestamp=context.input_request.timestamp, critical_alerts=critical_alerts + timestamp=context.input_request.timestamp, critical_alerts=critical_alerts, + initial_id=context.input_request.id ) self.cache[hash_key] = new_cache logger.info(f"Added cache entry for hash key: {hash_key}") + return self.cache[hash_key] + + def _update_cache_entry(self, hash_key: str, context: PipelineContext): + """Update an existing cache entry without changing the timestamp.""" + existing_entry = self.cache.get(hash_key) + if existing_entry is not None: + # Update critical alerts while retaining the original timestamp. + critical_alerts = [ + alert + for alert in context.alerts_raised + if alert.trigger_category == AlertSeverity.CRITICAL.value + ] + # Update the entry in the cache with new critical alerts but keep the old timestamp. + updated_cache = CachedFim( + timestamp=existing_entry.timestamp, critical_alerts=critical_alerts, + initial_id=existing_entry.initial_id + ) + self.cache[hash_key] = updated_cache + logger.info(f"Updated cache entry for hash key: {hash_key}") + else: + # Log a warning if trying to update a non-existent entry - ideally should not happen. + logger.warning(f"Attempted to update non-existent cache entry for hash key: {hash_key}") def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool: """Check if there are new alerts present""" @@ -108,29 +135,35 @@ def _are_new_alerts_present(self, context: PipelineContext, cached_entry: Cached def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool: """Check if the cached entry is old""" + if not context.input_request: + logger.warning("No input request found. Skipping checking if the cache entry is old") + return False elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds() - return elapsed_seconds > Config.get_config().max_fim_hash_lifetime + config = Config.get_config() + if config is None: + logger.warning("No configuration found. Skipping checking if the cache entry is old") + return True + return elapsed_seconds > Config.get_config().max_fim_hash_lifetime # type: ignore def could_store_fim_request(self, context: PipelineContext): + if not context.input_request: + logger.warning("No input request found. Skipping creating a mapping entry") + return False, '', '' # Couldn't process the user message. Skip creating a mapping entry. message = self._extract_message_from_fim_request(context.input_request.request) if message is None: logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.") - return False + return False, '', '' - hash_key = self._calculate_hash_key(message, context.input_request.provider) + hash_key = self._calculate_hash_key(message, context.input_request.provider) # type: ignore cached_entry = self.cache.get(hash_key, None) - if cached_entry is None: - self._add_cache_entry(hash_key, context) - return True - - if self._is_cached_entry_old(context, cached_entry): - self._add_cache_entry(hash_key, context) - return True - - if self._are_new_alerts_present(context, cached_entry): - self._add_cache_entry(hash_key, context) - return True - - logger.debug(f"FIM entry already in cache: {hash_key}.") - return False + if cached_entry is None or self._is_cached_entry_old( + context, cached_entry) or self._are_new_alerts_present(context, cached_entry): + cached_entry = self._add_cache_entry(hash_key, context) + if cached_entry is None: + logger.warning("Failed to add cache entry") + return False, '', '' + return True, 'add', cached_entry.initial_id + + self._update_cache_entry(hash_key, context) + return True, 'update', cached_entry.initial_id diff --git a/tests/db/test_fim_cache.py b/tests/db/test_fim_cache.py index 6da4de5f..c6b5506e 100644 --- a/tests/db/test_fim_cache.py +++ b/tests/db/test_fim_cache.py @@ -127,7 +127,7 @@ def test_extract_message_from_fim_request(test_request, expected_result_content) def test_are_new_alerts_present(): fim_cache = FimCache() - cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[]) + cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[], initial_id="1") context = PipelineContext() context.alerts_raised = [mock.MagicMock(trigger_category=AlertSeverity.CRITICAL.value)] result = fim_cache._are_new_alerts_present(context, cached_entry) @@ -146,6 +146,7 @@ def test_are_new_alerts_present(): trigger_string=None, ) ], + initial_id='2' ) result = fim_cache._are_new_alerts_present(context, populated_cache) assert result is False @@ -155,15 +156,17 @@ def test_are_new_alerts_present(): "cached_entry, is_old", [ ( - CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), critical_alerts=[]), + CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), + critical_alerts=[], initial_id='1'), True, ), - (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[]), False), + (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[], + initial_id='2'), False), ], ) def test_is_cached_entry_old(cached_entry, is_old): context = PipelineContext() - context.add_input_request("test", True, "test_provider") + context.add_input_request("test", True, "test_provider") # type: ignore fim_cache = FimCache() result = fim_cache._is_cached_entry_old(context, cached_entry) assert result == is_old