class PipelineType(Enum):
    """Kinds of pipeline a route can select: fill-in-the-middle or chat."""

    FIM = "fim"
    CHAT = "chat"


@dataclass
class PipelineRoute:
    """Pairs a request path with the pipeline type that should handle it."""

    path: str
    pipeline_type: PipelineType
    # Optional upstream URL for the route; none of the routes below set it.
    target_url: Optional[str] = None


# Routes are listed as (path, pipeline type) pairs and expanded into
# PipelineRoute objects in the same order.
PIPELINE_ROUTES = [
    PipelineRoute(path=route_path, pipeline_type=route_kind)
    for route_path, route_kind in (
        ("v1/chat/completions", PipelineType.CHAT),
        ("v1/engines/copilot-codex/completions", PipelineType.FIM),
        ("chat/completions", PipelineType.CHAT),
    )
]
b/src/codegate/providers/copilot/provider.py @@ -15,7 +15,7 @@ from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.output import OutputPipelineInstance from codegate.pipeline.secrets.manager import SecretsManager -from codegate.providers.copilot.mapping import VALIDATED_ROUTES +from codegate.providers.copilot.mapping import PIPELINE_ROUTES, VALIDATED_ROUTES, PipelineType from codegate.providers.copilot.pipeline import ( CopilotChatPipeline, CopilotFimPipeline, @@ -153,12 +153,18 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.context_tracking: Optional[PipelineContext] = None def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: - if method == "POST" and path == "v1/engines/copilot-codex/completions": - logger.debug("Selected CopilotFimStrategy") - return CopilotFimPipeline(self.pipeline_factory) - if method == "POST" and path == "chat/completions": - logger.debug("Selected CopilotChatStrategy") - return CopilotChatPipeline(self.pipeline_factory) + if method != "POST": + logger.debug("Not a POST request, no pipeline selected") + return None + + for route in PIPELINE_ROUTES: + if path == route.path: + if route.pipeline_type == PipelineType.FIM: + logger.debug("Selected FIM pipeline") + return CopilotFimPipeline(self.pipeline_factory) + elif route.pipeline_type == PipelineType.CHAT: + logger.debug("Selected CHAT pipeline") + return CopilotChatPipeline(self.pipeline_factory) logger.debug("No pipeline selected") return None @@ -350,8 +356,82 @@ async def _forward_data_to_target(self, data: bytes) -> None: pipeline_output = pipeline_output.reconstruct() self.target_transport.write(pipeline_output) + def _has_complete_body(self) -> bool: + """ + Check if we have received the complete request body based on Content-Length header. 
+ + We check the headers from the buffer instead of using self.request.headers on purpose + because with CONNECT requests, the whole request arrives in the data and is stored in + the buffer. + """ + try: + # For the initial CONNECT request + if not self.headers_parsed and self.request and self.request.method == "CONNECT": + return True + + # For subsequent requests or non-CONNECT requests, parse the method from the buffer + try: + first_line = self.buffer[: self.buffer.index(b"\r\n")].decode("utf-8") + method = first_line.split()[0] + except (ValueError, IndexError): + # Haven't received the complete request line yet + return False + + if method != "POST": # do we need to check for other methods? PUT? + return True + + # Parse headers from the buffer instead of using self.request.headers + headers_dict = {} + try: + headers_end = self.buffer.index(b"\r\n\r\n") + if headers_end <= 0: # Ensure we have a valid headers section + return False + + headers = self.buffer[:headers_end].split(b"\r\n") + if len(headers) <= 1: # Ensure we have headers after the request line + return False + + for header in headers[1:]: # Skip the request line + if not header: # Skip empty lines + continue + try: + name, value = header.decode("utf-8").split(":", 1) + headers_dict[name.strip().lower()] = value.strip() + except ValueError: + # Skip malformed headers + continue + except ValueError: + # Haven't received the complete headers yet + return False + + # TODO: Add proper support for chunked transfer encoding + # For now, just pass through and let the pipeline handle it + if "transfer-encoding" in headers_dict: + return True + + try: + content_length = int(headers_dict.get("content-length")) + except (ValueError, TypeError): + # Content-Length header is required for POST requests without chunked encoding + logger.error("Missing or invalid Content-Length header in POST request") + return False + + body_start = headers_end + 4 # Add safety check for buffer length + if body_start >= 
len(self.buffer): + return False + + current_body_length = len(self.buffer) - body_start + return current_body_length >= content_length + except Exception as e: + logger.error(f"Error checking body completion: {e}") + return False + def data_received(self, data: bytes) -> None: - """Handle received data from client""" + """ + Handle received data from client. Since we need to process the complete body + through our pipeline before forwarding, we accumulate the entire request first. + """ + logger.info(f"Received data from {self.peername}: {data}") try: if not self._check_buffer_size(data): self.send_error_response(413, b"Request body too large") @@ -364,10 +444,17 @@ def data_received(self, data: bytes) -> None: if self.headers_parsed: if self.request.method == "CONNECT": self.handle_connect() + self.buffer.clear() else: + # Only process the request once we have the complete body asyncio.create_task(self.handle_http_request()) else: - asyncio.create_task(self._forward_data_to_target(data)) + if self._has_complete_body(): + # Process the complete request through the pipeline + complete_request = bytes(self.buffer) + logger.debug(f"Complete request: {complete_request}") + self.buffer.clear() + asyncio.create_task(self._forward_data_to_target(complete_request)) except Exception as e: logger.error(f"Error processing received data: {e}") diff --git a/tests/integration/checks.py b/tests/integration/checks.py new file mode 100644 index 00000000..8ca964e6 --- /dev/null +++ b/tests/integration/checks.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from typing import List + +import structlog +from sklearn.metrics.pairwise import cosine_similarity + +from codegate.inference.inference_engine import LlamaCppInferenceEngine + +logger = structlog.get_logger("codegate") + + +class BaseCheck(ABC): + def __init__(self, test_name: str): + self.test_name = test_name + + @abstractmethod + async def run_check(self, parsed_response: str, test_data: dict) -> bool: + pass + + 
class CheckLoader:
    """Builds the checks that a test case asks for."""

    @staticmethod
    def load(test_data: dict) -> List[BaseCheck]:
        """
        Create one check instance for every check key with a truthy value in test_data.

        The returned list keeps a fixed order: DistanceCheck, ContainsCheck,
        DoesNotContainCheck. Each check is constructed with the test's "name".
        """
        name = test_data.get("name")
        candidates = (DistanceCheck, ContainsCheck, DoesNotContainCheck)
        return [check_cls(name) for check_cls in candidates if test_data.get(check_cls.KEY)]
def call_codegate(
    self, url: str, headers: dict, data: dict, provider: str
) -> Optional[requests.Response]:
    """
    Send one test request through the requester that matches the provider.

    Args:
        url: Target URL of the request.
        headers: HTTP headers to send.
        data: Request payload, handed to the requester for serialization.
        provider: Provider name used to pick the requester.

    Returns:
        The response, whatever its status code, or None when the request could not
        be completed (for example a connection error). Non-200 responses are logged
        before being returned.
    """
    logger.debug(f"Creating requester for provider: {provider}")
    requester = self.requester_factory.create_requester(provider)
    logger.debug(f"Using requester type: {requester.__class__.__name__}")

    logger.debug(f"Making request to URL: {url}")
    logger.debug(f"Headers: {headers}")
    logger.debug(f"Data: {data}")

    try:
        response = requester.make_request(url, headers, data)
    except requests.RequestException as e:
        # A transport failure must not abort the whole run: report it and let the
        # caller treat it as a missing response.
        logger.exception("An error occurred: %s", e)
        response = None

    if response is None:
        logger.error("No response received")
        return None

    if response.status_code != 200:
        logger.debug(f"Response error status: {response.status_code}")
        logger.debug(f"Response error headers: {dict(response.headers)}")
        try:
            error_content = response.json()
        except ValueError:
            # If not JSON, fall back to the raw text
            logger.error(f"Raw request error: {response.text}")
        else:
            logger.error(f"Request error as JSON: {error_content}")

    return response
async def run_tests(
    self,
    testcases_file: str,
    providers: Optional[list[str]] = None,
    test_names: Optional[list[str]] = None,
) -> None:
    """
    Load the test cases from a YAML file and run those matching the optional filters.

    Args:
        testcases_file: Path to a YAML file with top-level "headers" and "testcases"
            mappings.
        providers: When given, only tests whose "provider" is in this list are run
            (case-insensitive).
        test_names: When given, only tests whose "name" is in this list are run
            (case-insensitive).
    """
    with open(testcases_file, "r") as f:
        tests = yaml.safe_load(f)

    headers = tests["headers"]
    testcases = tests["testcases"]

    # Lower-case the filters once instead of once per test case.
    wanted_providers = {p.lower() for p in providers} if providers else None
    wanted_names = {t.lower() for t in test_names} if test_names else None

    # Human-readable description of the active filters, reused by both log messages.
    filter_msg = []
    if providers:
        filter_msg.append(f"providers: {', '.join(providers)}")
    if test_names:
        filter_msg.append(f"test names: {', '.join(test_names)}")
    filter_desc = " and ".join(filter_msg)

    if wanted_providers is not None or wanted_names is not None:
        testcases = {
            test_id: test_data
            for test_id, test_data in testcases.items()
            if (
                wanted_providers is None
                or test_data.get("provider", "").lower() in wanted_providers
            )
            and (wanted_names is None or test_data.get("name", "").lower() in wanted_names)
        }

        if not testcases:
            logger.warning(f"No tests found for {filter_desc}")
            return

    test_count = len(testcases)
    logger.info(f"Running {test_count} tests" + (f" for {filter_desc}" if filter_desc else ""))

    for test_data in testcases.values():
        # A provider declared with no headers (an empty YAML value) loads as None,
        # so .get()'s default does not apply; fall back to an empty mapping.
        provider_headers = headers.get(test_data["provider"]) or {}
        test_headers = {
            k: self.replace_env_variables(v, os.environ) for k, v in provider_headers.items()
        }
        await self.run_test(test_data, test_headers)
class CopilotRequester(BaseRequester):
    """Requester that sends the request through the proxy at localhost:8990."""

    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """
        POST data as a JSON body to url via the proxy, with a streamed response.

        Args:
            url: Target URL of the request.
            headers: HTTP headers to send; the caller's dict is left untouched.
            data: Payload, serialized to a JSON string before sending.

        Returns:
            The response returned by requests.post.
        """
        # Build a new dict so the forced Content-Type does not leak into the
        # caller's headers; also tolerate headers being None.
        request_headers = {**(headers or {}), "Content-Type": "application/json"}

        return requests.post(
            url,
            data=json.dumps(data),  # Explicit JSON string, sent via data= rather than json=
            headers=request_headers,
            proxies={"https": "https://localhost:8990", "http": "http://localhost:8990"},
            # TLS verification setting comes from the CA_CERT_FILE environment variable.
            verify=os.environ.get("CA_CERT_FILE"),
            stream=True,
        )
+ + copilot_malicious_package_question: + name: Copilot User asks about a malicious package + provider: copilot + url: "https://api.openai.com/v1/chat/completions" + data: | + { + "messages":[ + { + "content":"Generate me example code using the python invokehttp package to call an API", + "role":"user" + } + ], + "model":"gpt-4o", + "stream":true + } + contains: | + https://www.insight.stacklok.com/report/pypi/invokehttp + does_not_contain: | + import invokehttp + llamacpp_chat: name: LlamaCPP Chat provider: llamacpp @@ -30,7 +71,7 @@ testcases: "stream":true, "temperature":0 } - expected: | + likes: | Hello! How can I assist you today? llamacpp_fim: @@ -46,7 +87,7 @@ testcases: "stop": ["<|endoftext|>", "<|fim_prefix|>", "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>", "<|repo_name|>", "<|file_sep|>", "<|im_start|>", "<|im_end|>", "/src/", "#- coding: utf-8", "```"], "prompt":"<|fim_prefix|>\n# codegate/test.py\nimport invokehttp\nimport requests\n\nkey = \"mysecret-key\"\n\ndef call_api():\n <|fim_suffix|>\n\n\ndata = {'key1': 'test1', 'key2': 'test2'}\nresponse = call_api('http://localhost:8080', method='post', data='data')\n<|fim_middle|>" } - expected: | + likes: | url = 'http://localhost:8080' headers = {'Authorization': f'Bearer {key}'} response = requests.get(url, headers=headers) @@ -73,7 +114,7 @@ testcases: "stream":true, "temperature":0 } - expected: | + likes: | Hello! How can I assist you today? openai_fim: @@ -99,7 +140,7 @@ testcases: "```" ] } - expected: | + likes: | response = requests.post('http://localhost:8080', json=data, headers={'Authorization': f'Bearer {key}'}) vllm_chat: @@ -123,7 +164,7 @@ testcases: "stream":true, "temperature":0 } - expected: | + likes: | Hello! How can I assist you today? If you have any questions about software security, package analysis, or need guidance on secure coding practices, feel free to ask. 
vllm_fim: @@ -139,7 +180,7 @@ testcases: "stop": ["<|endoftext|>", "<|fim_prefix|>", "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>", "<|repo_name|>", "<|file_sep|>", "<|im_start|>", "<|im_end|>", "/src/", "#- coding: utf-8", "```"], "prompt":"<|fim_prefix|>\n# codegate/test.py\nimport invokehttp\nimport requests\n\nkey = \"mysecret-key\"\n\ndef call_api():\n <|fim_suffix|>\n\n\ndata = {'key1': 'test1', 'key2': 'test2'}\nresponse = call_api('http://localhost:8080', method='post', data='data')\n<|fim_middle|>" } - expected: | + likes: | # Create an instance of the InvokeHTTP class invoke = invokehttp.InvokeHTTP(key) @@ -175,7 +216,7 @@ testcases: "stream":true, "temperature":0 } - expected: | + likes: | Hello! I'm CodeGate, your security-focused AI assistant. I can help you with: - Software security analysis and reviews @@ -216,7 +257,7 @@ testcases: ], "system": "" } - expected: | + likes: | def call_api(url, method='get', data=None): if method.lower() == 'get': return requests.get(url) @@ -246,7 +287,7 @@ testcases: "stream":true, "temperature":0 } - expected: | + likes: | Hello! How can I assist you today? If you have any questions or need guidance on secure coding practices, software security, package analysis, or anything else related to cybersecurity, feel free to ask! ollama_fim: @@ -273,7 +314,7 @@ testcases: ], "prompt":"<|fim_prefix|>\n# codegate/test.py\nimport invokehttp\nimport requests\n\nkey = \"mysecret-key\"\n\ndef call_api():\n <|fim_suffix|>\n\n\ndata = {'key1': 'test1', 'key2': 'test2'}\nresponse = call_api('http://localhost:8080', method='post', data='data')\n<|fim_middle|>" } - expected: | + likes: | ```python import invokehttp import requests