|
4 | 4 | import sqlite3
|
5 | 5 | import uuid
|
6 | 6 | from pathlib import Path
|
7 |
| -from typing import List, Optional, Type |
| 7 | +from typing import List, Optional, Tuple, Type |
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 | import sqlite_vec_sl_tmp
|
@@ -716,6 +716,61 @@ async def get_prompts_with_output(
|
716 | 716 | )
|
717 | 717 | return prompts
|
718 | 718 |
|
| 719 | + def _build_prompt_query( |
| 720 | + self, |
| 721 | + base_query: str, |
| 722 | + workspace_id: str, |
| 723 | + filter_by_ids: Optional[List[str]] = None, |
| 724 | + filter_by_alert_trigger_categories: Optional[List[str]] = None, |
| 725 | + filter_by_alert_trigger_types: Optional[List[str]] = None, |
| 726 | + offset: Optional[int] = None, |
| 727 | + page_size: Optional[int] = None, |
| 728 | + ) -> Tuple[str, dict]: |
| 729 | + """ |
| 730 | + Helper method to construct SQL query and conditions for prompts based on filters. |
| 731 | +
|
| 732 | + Args: |
| 733 | + base_query: The base SQL query string with a placeholder for filter conditions. |
| 734 | + workspace_id: The ID of the workspace to fetch prompts from. |
| 735 | + filter_by_ids: Optional list of prompt IDs to filter by. |
| 736 | + filter_by_alert_trigger_categories: Optional list of alert categories to filter by. |
| 737 | + filter_by_alert_trigger_types: Optional list of alert trigger types to filter by. |
| 738 | + offset: Number of records to skip (for pagination). |
| 739 | + page_size: Number of records per page. |
| 740 | +
|
| 741 | + Returns: |
| 742 | + A tuple containing the formatted SQL query string and a dictionary of conditions. |
| 743 | + """ |
| 744 | + conditions = {"workspace_id": workspace_id} |
| 745 | + filter_conditions = [] |
| 746 | + |
| 747 | + if filter_by_alert_trigger_categories: |
| 748 | + filter_conditions.append( |
| 749 | + "AND (a.trigger_category IN :filter_by_alert_trigger_categories OR a.trigger_category IS NULL)" |
| 750 | + ) |
| 751 | + conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories |
| 752 | + |
| 753 | + if filter_by_alert_trigger_types: |
| 754 | + filter_conditions.append( |
| 755 | + "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" |
| 756 | + ) |
| 757 | + conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types |
| 758 | + |
| 759 | + if filter_by_ids: |
| 760 | + filter_conditions.append("AND p.id IN :filter_by_ids") |
| 761 | + conditions["filter_by_ids"] = filter_by_ids |
| 762 | + |
| 763 | + if offset is not None: |
| 764 | + conditions["offset"] = offset |
| 765 | + |
| 766 | + if page_size is not None: |
| 767 | + conditions["page_size"] = page_size |
| 768 | + |
| 769 | + filter_clause = " ".join(filter_conditions) |
| 770 | + query = base_query.format(filter_conditions=filter_clause) |
| 771 | + |
| 772 | + return query, conditions |
| 773 | + |
719 | 774 | async def get_prompts(
|
720 | 775 | self,
|
721 | 776 | workspace_id: str,
|
@@ -749,39 +804,19 @@ async def get_prompts(
|
749 | 804 | ORDER BY p.timestamp DESC
|
750 | 805 | LIMIT :page_size OFFSET :offset
|
751 | 806 | """
|
752 |
| - # Build conditions and filters |
753 |
| - conditions = { |
754 |
| - "workspace_id": workspace_id, |
755 |
| - "page_size": page_size, |
756 |
| - "offset": offset, |
757 |
| - } |
758 |
| - |
759 |
| - # Conditionally add filter clauses and conditions |
760 |
| - filter_conditions = [] |
761 |
| - |
762 |
| - if filter_by_alert_trigger_categories: |
763 |
| - filter_conditions.append( |
764 |
| - "AND a.trigger_category IN :filter_by_alert_trigger_categories" |
765 |
| - ) |
766 |
| - conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories |
767 |
| - |
768 |
| - if filter_by_alert_trigger_types: |
769 |
| - filter_conditions.append( |
770 |
| - "AND EXISTS (SELECT 1 FROM alerts a2 WHERE a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" # noqa: E501 |
771 |
| - ) |
772 |
| - conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types |
773 |
| - |
774 |
| - if filter_by_ids: |
775 |
| - filter_conditions.append("AND p.id IN :filter_by_ids") |
776 |
| - conditions["filter_by_ids"] = filter_by_ids |
777 |
| - |
778 |
| - filter_clause = " ".join(filter_conditions) |
779 |
| - query = base_query.format(filter_conditions=filter_clause) |
780 | 807 |
|
| 808 | + query, conditions = self._build_prompt_query( |
| 809 | + base_query, |
| 810 | + workspace_id, |
| 811 | + filter_by_ids, |
| 812 | + filter_by_alert_trigger_categories, |
| 813 | + filter_by_alert_trigger_types, |
| 814 | + offset, |
| 815 | + page_size, |
| 816 | + ) |
781 | 817 | sql = text(query)
|
782 | 818 |
|
783 | 819 | # Bind optional params
|
784 |
| - |
785 | 820 | if filter_by_alert_trigger_categories:
|
786 | 821 | sql = sql.bindparams(bindparam("filter_by_alert_trigger_categories", expanding=True))
|
787 | 822 | if filter_by_alert_trigger_types:
|
@@ -813,28 +848,14 @@ async def get_total_messages_count_by_workspace_id(
|
813 | 848 | WHERE p.workspace_id = :workspace_id
|
814 | 849 | {filter_conditions}
|
815 | 850 | """
|
816 |
| - conditions = {"workspace_id": workspace_id} |
817 |
| - filter_conditions = [] |
818 |
| - |
819 |
| - if filter_by_alert_trigger_categories: |
820 |
| - filter_conditions.append( |
821 |
| - "AND a.trigger_category IN :filter_by_alert_trigger_categories" |
822 |
| - ) |
823 |
| - conditions["filter_by_alert_trigger_categories"] = filter_by_alert_trigger_categories |
824 |
| - |
825 |
| - if filter_by_alert_trigger_types: |
826 |
| - filter_conditions.append( |
827 |
| - "AND EXISTS (SELECT 1 FROM alerts a2 WHERE " |
828 |
| - "a2.prompt_id = p.id AND a2.trigger_type IN :filter_by_alert_trigger_types)" |
829 |
| - ) |
830 |
| - conditions["filter_by_alert_trigger_types"] = filter_by_alert_trigger_types |
831 | 851 |
|
832 |
| - if filter_by_ids: |
833 |
| - filter_conditions.append("AND p.id IN :filter_by_ids") |
834 |
| - conditions["filter_by_ids"] = filter_by_ids |
835 |
| - |
836 |
| - filter_clause = " ".join(filter_conditions) |
837 |
| - query = base_query.format(filter_conditions=filter_clause) |
| 852 | + query, conditions = self._build_prompt_query( |
| 853 | + base_query, |
| 854 | + workspace_id, |
| 855 | + filter_by_ids, |
| 856 | + filter_by_alert_trigger_categories, |
| 857 | + filter_by_alert_trigger_types, |
| 858 | + ) |
838 | 859 | sql = text(query)
|
839 | 860 |
|
840 | 861 | # Bind optional params
|
|
0 commit comments