Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Run update call on recurring schedule #1268

Merged
merged 2 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from typing import List, Optional
from uuid import UUID

import cachetools.func
import requests
import structlog
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError

from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
import codegate.muxing.models as mux_models
from codegate import Config, __version__
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.config import API_DEFAULT_PAGE_SIZE, API_MAX_PAGE_SIZE
from codegate.db.connection import AlreadyExistsError, DbReader
from codegate.db.models import AlertSeverity, AlertTriggerType, Persona, WorkspaceWithModel
from codegate.muxing.persona import (
Expand All @@ -20,7 +21,7 @@
PersonaSimilarDescriptionError,
)
from codegate.providers import crud as provendcrud
from codegate.updates.client import Origin, UpdateClient
from codegate.updates.client import Origin, get_update_client_singleton
from codegate.workspaces import crud

logger = structlog.get_logger("codegate")
Expand All @@ -32,7 +33,6 @@

# This is a singleton object
dbreader = DbReader()
update_client = UpdateClient(Config.get_config().update_service_url, __version__, dbreader)


def uniq_name(route: APIRoute):
Expand Down Expand Up @@ -728,10 +728,7 @@ async def stream_sse():
@v1.get("/version", tags=["Dashboard"], generate_unique_id_function=uniq_name)
async def version_check():
try:
if Config.get_config().use_update_service:
latest_version = await update_client.get_latest_version(Origin.FrontEnd)
else:
latest_version = v1_processing.fetch_latest_version()
latest_version = _get_latest_version()
# normalize the versions as github will return them with a 'v' prefix
current_version = __version__.lstrip("v")
latest_version_stripped = latest_version.lstrip("v")
Expand Down Expand Up @@ -885,3 +882,9 @@ async def delete_persona(persona_name: str):
except Exception:
logger.exception("Error while deleting persona")
raise HTTPException(status_code=500, detail="Internal server error")


@cachetools.func.ttl_cache(maxsize=128, ttl=20 * 60)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the caching here since I do not want to cache the backend calls.

def _get_latest_version():
update_client = get_update_client_singleton()
return update_client.get_latest_version(Origin.FrontEnd)
12 changes: 0 additions & 12 deletions src/codegate/api/v1_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from collections import defaultdict
from typing import AsyncGenerator, Dict, List, Optional, Tuple

import cachetools.func
import regex as re
import requests
import structlog

from codegate.api import v1_models
Expand Down Expand Up @@ -34,16 +32,6 @@
]


@cachetools.func.ttl_cache(maxsize=128, ttl=20 * 60)
def fetch_latest_version() -> str:
url = "https://api.github.com/repos/stacklok/codegate/releases/latest"
headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
response = requests.get(url, headers=headers, timeout=5)
response.raise_for_status()
data = response.json()
return data.get("tag_name", "unknown")


async def generate_sse_events() -> AsyncGenerator[str, None]:
"""
SSE generator from queue
Expand Down
13 changes: 12 additions & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from uvicorn.config import Config as UvicornConfig
from uvicorn.server import Server

import codegate
from codegate.ca.codegate_ca import CertificateAuthority
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
Expand All @@ -25,6 +26,8 @@
from codegate.providers.copilot.provider import CopilotProvider
from codegate.server import init_app
from codegate.storage.utils import restore_storage_backup
from codegate.updates.client import init_update_client_singleton
from codegate.updates.scheduled import ScheduledUpdateChecker
from codegate.workspaces import crud as wscrud


Expand Down Expand Up @@ -322,9 +325,17 @@ def serve( # noqa: C901
logger = structlog.get_logger("codegate").bind(origin="cli")

init_db_sync(cfg.db_path)
init_instance(cfg.db_path)
instance_id = init_instance(cfg.db_path)
init_session_if_not_exists(cfg.db_path)

# Initialize the update checking logic.
update_client = init_update_client_singleton(
cfg.update_service_url, codegate.__version__, instance_id
)
update_checker = ScheduledUpdateChecker(update_client)
update_checker.daemon = True
update_checker.start()

# Check certificates and create CA if necessary
logger.info("Checking certificates and creating CA if needed")
ca = CertificateAuthority.get_instance()
Expand Down
Empty file.
10 changes: 0 additions & 10 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,6 @@ def from_env(cls) -> "Config":
config.db_path = os.environ["CODEGATE_DB_PATH"]
if "CODEGATE_VEC_DB_PATH" in os.environ:
config.vec_db_path = os.environ["CODEGATE_VEC_DB_PATH"]
if "CODEGATE_USE_UPDATE_SERVICE" in os.environ:
config.use_update_service = cls.__bool_from_string(
os.environ["CODEGATE_USE_UPDATE_SERVICE"]
)
if "CODEGATE_UPDATE_SERVICE_URL" in os.environ:
config.update_service_url = os.environ["CODEGATE_UPDATE_SERVICE_URL"]

Expand Down Expand Up @@ -258,7 +254,6 @@ def load(
force_certs: Optional[bool] = None,
db_path: Optional[str] = None,
vec_db_path: Optional[str] = None,
use_update_service: Optional[bool] = None,
update_service_url: Optional[str] = None,
) -> "Config":
"""Load configuration with priority resolution.
Expand Down Expand Up @@ -288,7 +283,6 @@ def load(
force_certs: Optional flag to force certificate generation
db_path: Optional path to the main SQLite database file
vec_db_path: Optional path to the vector SQLite database file
use_update_service: Optional flag to enable the update service
update_service_url: Optional URL for the update service

Returns:
Expand Down Expand Up @@ -342,8 +336,6 @@ def load(
config.db_path = env_config.db_path
if "CODEGATE_VEC_DB_PATH" in os.environ:
config.vec_db_path = env_config.vec_db_path
if "CODEGATE_USE_UPDATE_SERVICE" in os.environ:
config.use_update_service = env_config.use_update_service
if "CODEGATE_UPDATE_SERVICE_URL" in os.environ:
config.update_service_url = env_config.update_service_url

Expand Down Expand Up @@ -386,8 +378,6 @@ def load(
config.vec_db_path = vec_db_path
if force_certs is not None:
config.force_certs = force_certs
if use_update_service is not None:
config.use_update_service = use_update_service
if update_service_url is not None:
config.update_service_url = update_service_url

Expand Down
15 changes: 10 additions & 5 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,11 @@ async def delete_persona(self, persona_id: str) -> None:
conditions = {"id": persona_id}
await self._execute_with_no_return(sql, conditions)

async def init_instance(self) -> None:
async def init_instance(self) -> str:
"""
Initializes instance details in the database.
"""
instance_id = str(uuid.uuid4())
sql = text(
"""
INSERT INTO instance (id, created_at)
Expand All @@ -613,13 +614,14 @@ async def init_instance(self) -> None:

try:
instance = Instance(
id=str(uuid.uuid4()),
id=instance_id,
created_at=datetime.datetime.now(datetime.timezone.utc),
)
await self._execute_with_no_return(sql, instance.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError("Instance already initialized.")
return instance_id


class DbReader(DbCodeGate):
Expand Down Expand Up @@ -1326,18 +1328,21 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
logger.info("Session in DB initialized successfully.")


def init_instance(db_path: Optional[str] = None):
def init_instance(db_path: Optional[str] = None) -> str:
db_reader = DbReader(db_path)
instance = asyncio.run(db_reader.get_instance())
# Initialize instance if not already initialized.
if not instance:
db_recorder = DbRecorder(db_path)
try:
asyncio.run(db_recorder.init_instance())
instance_id = asyncio.run(db_recorder.init_instance())
logger.info("Instance initialized successfully.")
return instance_id
except Exception as e:
logger.error(f"Failed to initialize instance in DB: {e}")
raise
logger.info("Instance initialized successfully.")
else:
return instance[0].id


if __name__ == "__main__":
Expand Down
43 changes: 23 additions & 20 deletions src/codegate/updates/client.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,32 @@
from enum import Enum

import cachetools.func
import requests
import structlog

from codegate.db.connection import DbReader

logger = structlog.get_logger("codegate")


__update_client_singleton = None


# Enum representing whether the request is coming from the front-end or the back-end.
class Origin(Enum):
FrontEnd = "FE"
BackEnd = "BE"


class UpdateClient:
def __init__(self, update_url: str, current_version: str, db_reader: DbReader):
def __init__(self, update_url: str, current_version: str, instance_id: str):
self.__update_url = update_url
self.__current_version = current_version
self.__db_reader = db_reader
self.__instance_id = None
self.__instance_id = instance_id

async def get_latest_version(self, origin: Origin) -> str:
def get_latest_version(self, origin: Origin) -> str:
"""
Retrieves the latest version of CodeGate from updates.codegate.ai
"""
logger.info(f"Fetching latest version from {self.__update_url}")
instance_id = await self.__get_instance_id()
return self.__fetch_latest_version(instance_id, origin)

@cachetools.func.ttl_cache(maxsize=128, ttl=20 * 60)
def __fetch_latest_version(self, instance_id: str, origin: Origin) -> str:
headers = {
"X-Instance-ID": instance_id,
"X-Instance-ID": self.__instance_id,
"User-Agent": f"codegate/{self.__current_version} {origin.value}",
}

Expand All @@ -46,9 +39,19 @@ def __fetch_latest_version(self, instance_id: str, origin: Origin) -> str:
logger.error(f"Error fetching latest version from f{self.__update_url}: {e}")
return "unknown"

# Lazy load the instance ID from the DB.
async def __get_instance_id(self):
if self.__instance_id is None:
instance_data = await self.__db_reader.get_instance()
self.__instance_id = instance_data[0].id
return self.__instance_id

# Use a singleton since we do not have a good way of doing dependency injection
# with the API endpoints.
def init_update_client_singleton(
update_url: str, current_version: str, instance_id: str
) -> UpdateClient:
global __update_client_singleton
__update_client_singleton = UpdateClient(update_url, current_version, instance_id)
return __update_client_singleton


def get_update_client_singleton() -> UpdateClient:
global __update_client_singleton
if __update_client_singleton is None:
raise ValueError("UpdateClient singleton not initialized")
return __update_client_singleton
34 changes: 34 additions & 0 deletions src/codegate/updates/scheduled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import threading
import time

import structlog

import codegate
from codegate.updates.client import Origin, UpdateClient

logger = structlog.get_logger("codegate")


class ScheduledUpdateChecker(threading.Thread):
"""
ScheduledUpdateChecker calls the UpdateClient on a recurring interval.
This is implemented as a separate thread to avoid blocking the main thread.
A dedicated scheduling library could have been used, but the requirements
are trivial, and a simple hand-rolled solution is sufficient.
"""

def __init__(self, client: UpdateClient, interval_seconds: int = 14400): # 4 hours in seconds
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we want this interval to be externally configurable?

super().__init__()
self.__client = client
self.__interval_seconds = interval_seconds

def run(self):
"""
Overrides the `run` method of threading.Thread.
"""
while True:
logger.info("Checking for CodeGate updates")
latest = self.__client.get_latest_version(Origin.BackEnd)
if latest != codegate.__version__:
logger.warning(f"A new version of CodeGate is available: {latest}")
time.sleep(self.__interval_seconds)
14 changes: 8 additions & 6 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,20 @@ def test_health_check(test_client: TestClient) -> None:
assert response.json() == {"status": "healthy"}


@patch("codegate.api.v1_processing.fetch_latest_version", return_value="foo")
def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> None:
@patch("codegate.api.v1._get_latest_version")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this test would use a patched out UpdateClient instance, but it's difficult to patch out a singleton with the mocks.

This sort of testing would be much easier if the V1 routes were defined in a class, and all dependencies were supplied via the constructor.

def test_version_endpoint(mock_get_latest_version, test_client: TestClient) -> None:
"""Test the version endpoint."""
# Mock the __get_latest_version function to return a specific version
mock_get_latest_version.return_value = "v1.2.3"

response = test_client.get("/api/v1/version")
assert response.status_code == 200

response_data = response.json()

assert response_data["current_version"] == __version__.lstrip("v")
assert response_data["latest_version"] == "foo"
assert isinstance(response_data["is_latest"], bool)
assert response_data["current_version"] == "0.1.7"
assert response_data["latest_version"] == "1.2.3"
assert response_data["is_latest"] is False
assert response_data["error"] is None


@patch("codegate.pipeline.sensitive_data.manager.SensitiveDataManager")
Expand Down
Loading