This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 348e36a

Merge pull request #497 from stacklok/issue-472-v2
feat: improve cache system to collect the last output
2 parents ba4a9aa + 51cf196 commit 348e36a
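
The change teaches the FIM cache to remember the id of the first prompt that produced a given hash (initial_id) and widens the record path's return value from a bare boolean to a (should_record, action, initial_id) tuple. A condensed sketch of the resulting control flow in record_context (names taken from the diff below; self is the DB connection class in src/codegate/db/connection.py):

    should_record, action, initial_id = self._should_record_context(context)
    if should_record and action == 'add':
        # First time this request is seen: insert the prompt, outputs and alerts.
        await self.record_request(context.input_request)
        await self.record_outputs(context.output_responses, None)
        await self.record_alerts(context.alerts_raised, None)
    elif should_record:
        # Repeated FIM request: rewrite the original prompt row and attach the
        # latest outputs and alerts to it through the cached initial_id.
        await self.update_request(initial_id, context.input_request)
        await self.record_outputs(context.output_responses, initial_id)
        await self.record_alerts(context.alerts_raised, initial_id)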

File tree

4 files changed: +129 −54 lines

src/codegate/dashboard/dashboard.py
src/codegate/db/connection.py
src/codegate/db/fim_cache.py
tests/db/test_fim_cache.py

src/codegate/dashboard/dashboard.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 import asyncio
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional
 
 import structlog
 from fastapi import APIRouter, Depends
@@ -36,7 +36,7 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat
 
 
 @dashboard_router.get("/dashboard/alerts")
-def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[AlertConversation]:
+def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
     """
     Get all the messages from the database and return them as a list of conversations.
     """

src/codegate/db/connection.py

Lines changed: 69 additions & 30 deletions
@@ -5,7 +5,7 @@
 
 import structlog
 from pydantic import BaseModel
-from sqlalchemy import text
+from sqlalchemy import TextClause, text
 from sqlalchemy.ext.asyncio import create_async_engine
 
 from codegate.db.fim_cache import FimCache
@@ -30,8 +30,8 @@ def __init__(self, sqlite_path: Optional[str] = None):
             current_dir = Path(__file__).parent
             sqlite_path = (
                 current_dir.parent.parent.parent / "codegate_volume" / "db" / "codegate.db"
-            )
-        self._db_path = Path(sqlite_path).absolute()
+            )  # type: ignore
+        self._db_path = Path(sqlite_path).absolute()  # type: ignore
         self._db_path.parent.mkdir(parents=True, exist_ok=True)
         logger.debug(f"Initializing DB from path: {self._db_path}")
         engine_dict = {
@@ -82,15 +82,15 @@ async def init_db(self):
         finally:
             await self._async_db_engine.dispose()
 
-    async def _insert_pydantic_model(
-        self, model: BaseModel, sql_insert: text
+    async def _execute_update_pydantic_model(
+        self, model: BaseModel, sql_command: TextClause  #
     ) -> Optional[BaseModel]:
         # There are create method in queries.py automatically generated by sqlc
         # However, the methods are buggy for Pydancti and don't work as expected.
         # Manually writing the SQL query to insert Pydantic models.
         async with self._async_db_engine.begin() as conn:
             try:
-                result = await conn.execute(sql_insert, model.model_dump())
+                result = await conn.execute(sql_command, model.model_dump())
                 row = result.first()
                 if row is None:
                     return None
@@ -99,7 +99,7 @@ async def _insert_pydantic_model(
                 model_class = model.__class__
                 return model_class(**row._asdict())
             except Exception as e:
-                logger.error(f"Failed to insert model: {model}.", error=str(e))
+                logger.error(f"Failed to update model: {model}.", error=str(e))
                 return None
 
     async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
@@ -112,18 +112,39 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
             RETURNING *
             """
         )
-        recorded_request = await self._insert_pydantic_model(prompt_params, sql)
+        recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
         # Uncomment to debug the recorded request
         # logger.debug(f"Recorded request: {recorded_request}")
-        return recorded_request
+        return recorded_request  # type: ignore
 
-    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
+    async def update_request(self, initial_id: str,
+                             prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
+        if prompt_params is None:
+            return None
+        prompt_params.id = initial_id  # overwrite the initial id of the request
+        sql = text(
+            """
+            UPDATE prompts
+            SET timestamp = :timestamp, provider = :provider, request = :request, type = :type
+            WHERE id = :id
+            RETURNING *
+            """
+        )
+        updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
+        # Uncomment to debug the recorded request
+        # logger.debug(f"Recorded request: {recorded_request}")
+        return updated_request  # type: ignore
+
+    async def record_outputs(self, outputs: List[Output],
+                             initial_id: Optional[str]) -> Optional[Output]:
         if not outputs:
             return
 
         first_output = outputs[0]
         # Create a single entry on DB but encode all of the chunks in the stream as a list
         # of JSON objects in the field `output`
+        if initial_id:
+            first_output.prompt_id = initial_id
         output_db = Output(
             id=first_output.id,
             prompt_id=first_output.prompt_id,
@@ -143,14 +164,14 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
             RETURNING *
             """
         )
-        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        recorded_output = await self._execute_update_pydantic_model(output_db, sql)
         # Uncomment to debug
         # logger.debug(f"Recorded output: {recorded_output}")
-        return recorded_output
+        return recorded_output  # type: ignore
 
-    async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
+    async def record_alerts(self, alerts: List[Alert], initial_id: Optional[str]) -> List[Alert]:
         if not alerts:
-            return
+            return []
         sql = text(
             """
             INSERT INTO alerts (
@@ -167,7 +188,9 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         async with asyncio.TaskGroup() as tg:
             for alert in alerts:
                 try:
-                    result = tg.create_task(self._insert_pydantic_model(alert, sql))
+                    if initial_id:
+                        alert.prompt_id = initial_id
+                    result = tg.create_task(self._execute_update_pydantic_model(alert, sql))
                     alerts_tasks.append(result)
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
@@ -182,33 +205,49 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         # logger.debug(f"Recorded alerts: {recorded_alerts}")
         return recorded_alerts
 
-    def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
-        """Check if the context should be recorded in DB"""
+    def _should_record_context(self, context: Optional[PipelineContext]) -> tuple:
+        """Check if the context should be recorded in DB and determine the action."""
         if context is None or context.metadata.get("stored_in_db", False):
-            return False
+            return False, None, None
 
         if not context.input_request:
             logger.warning("No input request found. Skipping recording context.")
-            return False
+            return False, None, None
 
         # If it's not a FIM prompt, we don't need to check anything else.
         if context.input_request.type != "fim":
-            return True
+            return True, 'add', ''  # Default to add if not FIM, since no cache check is required
 
-        return fim_cache.could_store_fim_request(context)
+        return fim_cache.could_store_fim_request(context)  # type: ignore
 
     async def record_context(self, context: Optional[PipelineContext]) -> None:
        try:
-            if not self._should_record_context(context):
+            if not context:
+                logger.info("No context provided, skipping")
                 return
-            await self.record_request(context.input_request)
-            await self.record_outputs(context.output_responses)
-            await self.record_alerts(context.alerts_raised)
-            context.metadata["stored_in_db"] = True
-            logger.info(
-                f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
-                f"Alerts: {len(context.alerts_raised)}."
-            )
+            should_record, action, initial_id = self._should_record_context(context)
+            if not should_record:
+                logger.info("Skipping record of context, not needed")
+                return
+            if action == 'add':
+                await self.record_request(context.input_request)
+                await self.record_outputs(context.output_responses, None)
+                await self.record_alerts(context.alerts_raised, None)
+                context.metadata["stored_in_db"] = True
+                logger.info(
+                    f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                    f"Alerts: {len(context.alerts_raised)}."
+                )
+            else:
+                # update them
+                await self.update_request(initial_id, context.input_request)
+                await self.record_outputs(context.output_responses, initial_id)
+                await self.record_alerts(context.alerts_raised, initial_id)
+                context.metadata["stored_in_db"] = True
+                logger.info(
+                    f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+                    f"Alerts: {len(context.alerts_raised)}."
                )
         except Exception as e:
             logger.error(f"Failed to record context: {context}.", error=str(e))
 
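The renamed _execute_update_pydantic_model helper leans on SQLAlchemy's named bind parameters lining up with the Pydantic model's field names. A self-contained sketch of that pattern (illustrative only, with a Prompt model trimmed to two fields; it is not the project's code):

    from pydantic import BaseModel
    from sqlalchemy import text

    class Prompt(BaseModel):
        id: str
        provider: str

    sql = text("UPDATE prompts SET provider = :provider WHERE id = :id RETURNING *")
    params = Prompt(id="abc", provider="ollama").model_dump()
    # Inside the helper this becomes, on an async connection:
    #   result = await conn.execute(sql, params)
    #   row = result.first()  # None when the UPDATE matched no row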

src/codegate/db/fim_cache.py

Lines changed: 51 additions & 18 deletions
@@ -18,6 +18,7 @@ class CachedFim(BaseModel):
 
     timestamp: datetime.datetime
     critical_alerts: List[Alert]
+    initial_id: str
 
 
 class FimCache:
@@ -86,16 +87,42 @@ def _calculate_hash_key(self, message: str, provider: str) -> str:
 
     def _add_cache_entry(self, hash_key: str, context: PipelineContext):
         """Add a new cache entry"""
+        if not context.input_request:
+            logger.warning("No input request found. Skipping creating a mapping entry")
+            return
         critical_alerts = [
             alert
             for alert in context.alerts_raised
             if alert.trigger_category == AlertSeverity.CRITICAL.value
         ]
         new_cache = CachedFim(
-            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts
+            timestamp=context.input_request.timestamp, critical_alerts=critical_alerts,
+            initial_id=context.input_request.id
         )
         self.cache[hash_key] = new_cache
         logger.info(f"Added cache entry for hash key: {hash_key}")
+        return self.cache[hash_key]
+
+    def _update_cache_entry(self, hash_key: str, context: PipelineContext):
+        """Update an existing cache entry without changing the timestamp."""
+        existing_entry = self.cache.get(hash_key)
+        if existing_entry is not None:
+            # Update critical alerts while retaining the original timestamp.
+            critical_alerts = [
+                alert
+                for alert in context.alerts_raised
+                if alert.trigger_category == AlertSeverity.CRITICAL.value
+            ]
+            # Update the entry in the cache with new critical alerts but keep the old timestamp.
+            updated_cache = CachedFim(
+                timestamp=existing_entry.timestamp, critical_alerts=critical_alerts,
+                initial_id=existing_entry.initial_id
+            )
+            self.cache[hash_key] = updated_cache
+            logger.info(f"Updated cache entry for hash key: {hash_key}")
+        else:
+            # Log a warning if trying to update a non-existent entry - ideally should not happen.
+            logger.warning(f"Attempted to update non-existent cache entry for hash key: {hash_key}")
 
     def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
         """Check if there are new alerts present"""
@@ -108,29 +135,35 @@ def _are_new_alerts_present(self, context: PipelineContext, cached_entry: Cached
 
     def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
         """Check if the cached entry is old"""
+        if not context.input_request:
+            logger.warning("No input request found. Skipping checking if the cache entry is old")
+            return False
         elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
-        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime
+        config = Config.get_config()
+        if config is None:
+            logger.warning("No configuration found. Skipping checking if the cache entry is old")
+            return True
+        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime  # type: ignore
 
     def could_store_fim_request(self, context: PipelineContext):
+        if not context.input_request:
+            logger.warning("No input request found. Skipping creating a mapping entry")
+            return False, '', ''
         # Couldn't process the user message. Skip creating a mapping entry.
         message = self._extract_message_from_fim_request(context.input_request.request)
         if message is None:
             logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
-            return False
+            return False, '', ''
 
-        hash_key = self._calculate_hash_key(message, context.input_request.provider)
+        hash_key = self._calculate_hash_key(message, context.input_request.provider)  # type: ignore
         cached_entry = self.cache.get(hash_key, None)
-        if cached_entry is None:
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        if self._is_cached_entry_old(context, cached_entry):
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        if self._are_new_alerts_present(context, cached_entry):
-            self._add_cache_entry(hash_key, context)
-            return True
-
-        logger.debug(f"FIM entry already in cache: {hash_key}.")
-        return False
+        if cached_entry is None or self._is_cached_entry_old(
+                context, cached_entry) or self._are_new_alerts_present(context, cached_entry):
+            cached_entry = self._add_cache_entry(hash_key, context)
+            if cached_entry is None:
+                logger.warning("Failed to add cache entry")
+                return False, '', ''
+            return True, 'add', cached_entry.initial_id
+
+        self._update_cache_entry(hash_key, context)
+        return True, 'update', cached_entry.initial_id
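
With these changes could_store_fim_request returns a triple rather than a boolean. A hypothetical call sequence (first_context and repeat_context are placeholder PipelineContext objects, not real ones):

    fim_cache = FimCache()

    # First sighting of a FIM payload (or the cached entry has expired):
    should_record, action, initial_id = fim_cache.could_store_fim_request(first_context)
    # -> (True, 'add', <id of the prompt that created the cache entry>)

    # Same payload again while the entry is still fresh and has no new critical alerts:
    should_record, action, initial_id = fim_cache.could_store_fim_request(repeat_context)
    # -> (True, 'update', <same initial_id>), so the DB layer updates the original
    #    prompt and attaches the latest outputs and alerts to it.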

tests/db/test_fim_cache.py

Lines changed: 7 additions & 4 deletions
@@ -127,7 +127,7 @@ def test_extract_message_from_fim_request(test_request, expected_result_content)
 
 def test_are_new_alerts_present():
     fim_cache = FimCache()
-    cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[])
+    cached_entry = CachedFim(timestamp=datetime.now(), critical_alerts=[], initial_id="1")
     context = PipelineContext()
     context.alerts_raised = [mock.MagicMock(trigger_category=AlertSeverity.CRITICAL.value)]
     result = fim_cache._are_new_alerts_present(context, cached_entry)
@@ -146,6 +146,7 @@ def test_are_new_alerts_present():
                 trigger_string=None,
             )
         ],
+        initial_id='2'
     )
     result = fim_cache._are_new_alerts_present(context, populated_cache)
     assert result is False
@@ -155,15 +156,17 @@ def test_are_new_alerts_present():
     "cached_entry, is_old",
     [
         (
-            CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1), critical_alerts=[]),
+            CachedFim(timestamp=datetime.now(timezone.utc) - timedelta(days=1),
+                      critical_alerts=[], initial_id='1'),
             True,
         ),
-        (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[]), False),
+        (CachedFim(timestamp=datetime.now(timezone.utc), critical_alerts=[],
+                   initial_id='2'), False),
     ],
 )
 def test_is_cached_entry_old(cached_entry, is_old):
     context = PipelineContext()
-    context.add_input_request("test", True, "test_provider")
+    context.add_input_request("test", True, "test_provider")  # type: ignore
     fim_cache = FimCache()
     result = fim_cache._is_cached_entry_old(context, cached_entry)
     assert result == is_old
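
Because initial_id is now a required field on CachedFim, any test fixture that builds the model has to supply it. A minimal example in the same style as the updated tests (the value is arbitrary, chosen for illustration):

    from datetime import datetime, timezone

    entry = CachedFim(
        timestamp=datetime.now(timezone.utc),
        critical_alerts=[],
        initial_id="some-prompt-id",  # hypothetical value; any string works
    )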
