diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 0d3c372ce..e93b95c90 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,4 +1,5 @@ from .client.session import ClientSession +from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession from .server.stdio import stdio_server @@ -63,6 +64,7 @@ "ClientRequest", "ClientResult", "ClientSession", + "ClientSessionGroup", "CreateMessageRequest", "CreateMessageResult", "ErrorData", diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py new file mode 100644 index 000000000..c23f2523e --- /dev/null +++ b/src/mcp/client/session_group.py @@ -0,0 +1,372 @@ +""" +SessionGroup concurrently manages multiple MCP session connections. + +Tools, resources, and prompts are aggregated across servers. Servers may +be connected to or disconnected from at any point after initialization. + +This abstractions can handle naming collisions using a custom user-provided +hook. +""" + +import contextlib +import logging +from collections.abc import Callable +from datetime import timedelta +from types import TracebackType +from typing import Any, TypeAlias + +import anyio +from pydantic import BaseModel +from typing_extensions import Self + +import mcp +from mcp import types +from mcp.client.sse import sse_client +from mcp.client.stdio import StdioServerParameters +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.exceptions import McpError + + +class SseServerParameters(BaseModel): + """Parameters for intializing a sse_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: float = 5 + + # Timeout for SSE read operations. + sse_read_timeout: float = 60 * 5 + + +class StreamableHttpParameters(BaseModel): + """Parameters for intializing a streamablehttp_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: timedelta = timedelta(seconds=30) + + # Timeout for SSE read operations. + sse_read_timeout: timedelta = timedelta(seconds=60 * 5) + + # Close the client session when the transport closes. + terminate_on_close: bool = True + + +ServerParameters: TypeAlias = ( + StdioServerParameters | SseServerParameters | StreamableHttpParameters +) + + +class ClientSessionGroup: + """Client for managing connections to multiple MCP servers. + + This class is responsible for encapsulating management of server connections. + It aggregates tools, resources, and prompts from all connected servers. + + For auxiliary handlers, such as resource subscription, this is delegated to + the client and can be accessed via the session. + + Example Usage: + name_fn = lambda name, server_info: f"{(server_info.name)}-{name}" + async with ClientSessionGroup(component_name_hook=name_fn) as group: + for server_params in server_params: + group.connect_to_server(server_param) + ... + + """ + + class _ComponentNames(BaseModel): + """Used for reverse index to find components.""" + + prompts: set[str] = set() + resources: set[str] = set() + tools: set[str] = set() + + # Standard MCP components. + _prompts: dict[str, types.Prompt] + _resources: dict[str, types.Resource] + _tools: dict[str, types.Tool] + + # Client-server connection management. + _sessions: dict[mcp.ClientSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientSession] + _exit_stack: contextlib.AsyncExitStack + _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] + + # Optional fn consuming (component_name, serverInfo) for custom names. + # This is provide a means to mitigate naming conflicts across servers. + # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}" + _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] + _component_name_hook: _ComponentNameHook | None + + def __init__( + self, + exit_stack: contextlib.AsyncExitStack | None = None, + component_name_hook: _ComponentNameHook | None = None, + ) -> None: + """Initializes the MCP client.""" + + self._tools = {} + self._resources = {} + self._prompts = {} + + self._sessions = {} + self._tool_to_session = {} + if exit_stack is None: + self._exit_stack = contextlib.AsyncExitStack() + self._owns_exit_stack = True + else: + self._exit_stack = exit_stack + self._owns_exit_stack = False + self._session_exit_stacks = {} + self._component_name_hook = component_name_hook + + async def __aenter__(self) -> Self: + # Enter the exit stack only if we created it ourselves + if self._owns_exit_stack: + await self._exit_stack.__aenter__() + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: TracebackType | None, + ) -> bool | None: + """Closes session exit stacks and main exit stack upon completion.""" + + # Concurrently close session stacks. + async with anyio.create_task_group() as tg: + for exit_stack in self._session_exit_stacks.values(): + tg.start_soon(exit_stack.aclose) + + # Only close the main exit stack if we created it + if self._owns_exit_stack: + await self._exit_stack.aclose() + + @property + def sessions(self) -> list[mcp.ClientSession]: + """Returns the list of sessions being managed.""" + return list(self._sessions.keys()) + + @property + def prompts(self) -> dict[str, types.Prompt]: + """Returns the prompts as a dictionary of names to prompts.""" + return self._prompts + + @property + def resources(self) -> dict[str, types.Resource]: + """Returns the resources as a dictionary of names to resources.""" + return self._resources + + @property + def tools(self) -> dict[str, types.Tool]: + """Returns the tools as a dictionary of names to tools.""" + return self._tools + + async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + """Executes a tool given its name and arguments.""" + session = self._tool_to_session[name] + session_tool_name = self.tools[name].name + return await session.call_tool(session_tool_name, args) + + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: + """Disconnects from a single MCP server.""" + + session_known_for_components = session in self._sessions + session_known_for_stack = session in self._session_exit_stacks + + if not session_known_for_components and not session_known_for_stack: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message="Provided session is not managed or already disconnected.", + ) + ) + + if session_known_for_components: + component_names = self._sessions.pop(session) # Pop from _sessions tracking + + # Remove prompts associated with the session. + for name in component_names.prompts: + if name in self._prompts: + del self._prompts[name] + # Remove resources associated with the session. + for name in component_names.resources: + if name in self._resources: + del self._resources[name] + # Remove tools associated with the session. + for name in component_names.tools: + if name in self._tools: + del self._tools[name] + if name in self._tool_to_session: + del self._tool_to_session[name] + + # Clean up the session's resources via its dedicated exit stack + if session_known_for_stack: + session_stack_to_close = self._session_exit_stacks.pop(session) + await session_stack_to_close.aclose() + + async def connect_with_session( + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> mcp.ClientSession: + """Connects to a single MCP server.""" + await self._aggregate_components(server_info, session) + return session + + async def connect_to_server( + self, + server_params: ServerParameters, + ) -> mcp.ClientSession: + """Connects to a single MCP server.""" + server_info, session = await self._establish_session(server_params) + return await self.connect_with_session(server_info, session) + + async def _establish_session( + self, server_params: ServerParameters + ) -> tuple[types.Implementation, mcp.ClientSession]: + """Establish a client session to an MCP server.""" + + session_stack = contextlib.AsyncExitStack() + try: + # Create read and write streams that facilitate io with the server. + if isinstance(server_params, StdioServerParameters): + client = mcp.stdio_client(server_params) + read, write = await session_stack.enter_async_context(client) + elif isinstance(server_params, SseServerParameters): + client = sse_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + ) + read, write = await session_stack.enter_async_context(client) + else: + client = streamablehttp_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + terminate_on_close=server_params.terminate_on_close, + ) + read, write, _ = await session_stack.enter_async_context(client) + + session = await session_stack.enter_async_context( + mcp.ClientSession(read, write) + ) + result = await session.initialize() + + # Session successfully initialized. + # Store its stack and register the stack with the main group stack. + self._session_exit_stacks[session] = session_stack + # session_stack itself becomes a resource managed by the + # main _exit_stack. + await self._exit_stack.enter_async_context(session_stack) + + return result.serverInfo, session + except Exception: + # If anything during this setup fails, ensure the session-specific + # stack is closed. + await session_stack.aclose() + raise + + async def _aggregate_components( + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> None: + """Aggregates prompts, resources, and tools from a given session.""" + + # Create a reverse index so we can find all prompts, resources, and + # tools belonging to this session. Used for removing components from + # the session group via self.disconnect_from_server. + component_names = self._ComponentNames() + + # Temporary components dicts. We do not want to modify the aggregate + # lists in case of an intermediate failure. + prompts_temp: dict[str, types.Prompt] = {} + resources_temp: dict[str, types.Resource] = {} + tools_temp: dict[str, types.Tool] = {} + tool_to_session_temp: dict[str, mcp.ClientSession] = {} + + # Query the server for its prompts and aggregate to list. + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except McpError as err: + logging.warning(f"Could not fetch prompts: {err}") + + # Query the server for its resources and aggregate to list. + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except McpError as err: + logging.warning(f"Could not fetch resources: {err}") + + # Query the server for its tools and aggregate to list. + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except McpError as err: + logging.warning(f"Could not fetch tools: {err}") + + # Clean up exit stack for session if we couldn't retrieve anything + # from the server. + if not any((prompts_temp, resources_temp, tools_temp)): + del self._session_exit_stacks[session] + + # Check for duplicates. + matching_prompts = prompts_temp.keys() & self._prompts.keys() + if matching_prompts: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_prompts} already exist in group prompts.", + ) + ) + matching_resources = resources_temp.keys() & self._resources.keys() + if matching_resources: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_resources} already exist in group resources.", + ) + ) + matching_tools = tools_temp.keys() & self._tools.keys() + if matching_tools: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_tools} already exist in group tools.", + ) + ) + + # Aggregate components. + self._sessions[session] = component_names + self._prompts.update(prompts_temp) + self._resources.update(resources_temp) + self._tools.update(tools_temp) + self._tool_to_session.update(tool_to_session_temp) + + def _component_name(self, name: str, server_info: types.Implementation) -> str: + if self._component_name_hook: + return self._component_name_hook(name, server_info) + return name diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py new file mode 100644 index 000000000..924ef7a06 --- /dev/null +++ b/tests/client/test_session_group.py @@ -0,0 +1,397 @@ +import contextlib +from unittest import mock + +import pytest + +import mcp +from mcp import types +from mcp.client.session_group import ( + ClientSessionGroup, + SseServerParameters, + StreamableHttpParameters, +) +from mcp.client.stdio import StdioServerParameters +from mcp.shared.exceptions import McpError + + +@pytest.fixture +def mock_exit_stack(): + """Fixture for a mocked AsyncExitStack.""" + # Use unittest.mock.Mock directly if needed, or just a plain object + # if only attribute access/existence is needed. + # For AsyncExitStack, Mock or MagicMock is usually fine. + return mock.MagicMock(spec=contextlib.AsyncExitStack) + + +@pytest.mark.anyio +class TestClientSessionGroup: + def test_init(self): + mcp_session_group = ClientSessionGroup() + assert not mcp_session_group._tools + assert not mcp_session_group._resources + assert not mcp_session_group._prompts + assert not mcp_session_group._tool_to_session + + def test_component_properties(self): + # --- Mock Dependencies --- + mock_prompt = mock.Mock() + mock_resource = mock.Mock() + mock_tool = mock.Mock() + + # --- Prepare Session Group --- + mcp_session_group = ClientSessionGroup() + mcp_session_group._prompts = {"my_prompt": mock_prompt} + mcp_session_group._resources = {"my_resource": mock_resource} + mcp_session_group._tools = {"my_tool": mock_tool} + + # --- Assertions --- + assert mcp_session_group.prompts == {"my_prompt": mock_prompt} + assert mcp_session_group.resources == {"my_resource": mock_resource} + assert mcp_session_group.tools == {"my_tool": mock_tool} + + async def test_call_tool(self): + # --- Mock Dependencies --- + mock_session = mock.AsyncMock() + + # --- Prepare Session Group --- + def hook(name, server_info): + return f"{(server_info.name)}-{name}" + + mcp_session_group = ClientSessionGroup(component_name_hook=hook) + mcp_session_group._tools = { + "server1-my_tool": types.Tool(name="my_tool", inputSchema={}) + } + mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} + text_content = types.TextContent(type="text", text="OK") + mock_session.call_tool.return_value = types.CallToolResult( + content=[text_content] + ) + + # --- Test Execution --- + result = await mcp_session_group.call_tool( + name="server1-my_tool", + args={ + "name": "value1", + "args": {}, + }, + ) + + # --- Assertions --- + assert result.content == [text_content] + mock_session.call_tool.assert_called_once_with( + "my_tool", + {"name": "value1", "args": {}}, + ) + + async def test_connect_to_server(self, mock_exit_stack): + """Test connecting to a server and aggregating components.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "TestServer1" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool_a" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "resource_b" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prompt_c" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) + mock_session.list_resources.return_value = mock.AsyncMock( + resources=[mock_resource1] + ) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) + + # --- Test Execution --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object( + group, "_establish_session", return_value=(mock_server_info, mock_session) + ): + await group.connect_to_server(StdioServerParameters(command="test")) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + assert "tool_a" in group.tools + assert group.tools["tool_a"] == mock_tool1 + assert group._tool_to_session["tool_a"] == mock_session + assert len(group.resources) == 1 + assert "resource_b" in group.resources + assert group.resources["resource_b"] == mock_resource1 + assert len(group.prompts) == 1 + assert "prompt_c" in group.prompts + assert group.prompts["prompt_c"] == mock_prompt1 + mock_session.list_tools.assert_awaited_once() + mock_session.list_resources.assert_awaited_once() + mock_session.list_prompts.assert_awaited_once() + + async def test_connect_to_server_with_name_hook(self, mock_exit_stack): + """Test connecting with a component name hook.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "HookServer" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "base_tool" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Setup --- + def name_hook(name: str, server_info: types.Implementation) -> str: + return f"{server_info.name}.{name}" + + # --- Test Execution --- + group = ClientSessionGroup( + exit_stack=mock_exit_stack, component_name_hook=name_hook + ) + with mock.patch.object( + group, "_establish_session", return_value=(mock_server_info, mock_session) + ): + await group.connect_to_server(StdioServerParameters(command="test")) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + expected_tool_name = "HookServer.base_tool" + assert expected_tool_name in group.tools + assert group.tools[expected_tool_name] == mock_tool + assert group._tool_to_session[expected_tool_name] == mock_session + + async def test_disconnect_from_server(self): # No mock arguments needed + """Test disconnecting from a server.""" + # --- Test Setup --- + group = ClientSessionGroup() + server_name = "ServerToDisconnect" + + # Manually populate state using standard mocks + mock_session1 = mock.MagicMock(spec=mcp.ClientSession) + mock_session2 = mock.MagicMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool1" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "res1" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prm1" + mock_tool2 = mock.Mock(spec=types.Tool) + mock_tool2.name = "tool2" + mock_component_named_like_server = mock.Mock() + mock_session = mock.Mock(spec=mcp.ClientSession) + + group._tools = { + "tool1": mock_tool1, + "tool2": mock_tool2, + server_name: mock_component_named_like_server, + } + group._tool_to_session = { + "tool1": mock_session1, + "tool2": mock_session2, + server_name: mock_session1, + } + group._resources = { + "res1": mock_resource1, + server_name: mock_component_named_like_server, + } + group._prompts = { + "prm1": mock_prompt1, + server_name: mock_component_named_like_server, + } + group._sessions = { + mock_session: ClientSessionGroup._ComponentNames( + prompts=set({"prm1"}), + resources=set({"res1"}), + tools=set({"tool1", "tool2"}), + ) + } + + # --- Assertions --- + assert mock_session in group._sessions + assert "tool1" in group._tools + assert "tool2" in group._tools + assert "res1" in group._resources + assert "prm1" in group._prompts + + # --- Test Execution --- + await group.disconnect_from_server(mock_session) + + # --- Assertions --- + assert mock_session not in group._sessions + assert "tool1" not in group._tools + assert "tool2" not in group._tools + assert "res1" not in group._resources + assert "prm1" not in group._prompts + + async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack): + """Test McpError raised when connecting a server with a dup name.""" + # --- Setup Pre-existing State --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + existing_tool_name = "shared_tool" + # Manually add a tool to simulate a previous connection + group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) + group._tools[existing_tool_name].name = existing_tool_name + # Need a dummy session associated with the existing tool + mock_session = mock.MagicMock(spec=mcp.ClientSession) + group._tool_to_session[existing_tool_name] = mock_session + group._session_exit_stacks[mock_session] = mock.Mock( + spec=contextlib.AsyncExitStack + ) + + # --- Mock New Connection Attempt --- + mock_server_info_new = mock.Mock(spec=types.Implementation) + mock_server_info_new.name = "ServerWithDuplicate" + mock_session_new = mock.AsyncMock(spec=mcp.ClientSession) + + # Configure the new session to return a tool with the *same name* + duplicate_tool = mock.Mock(spec=types.Tool) + duplicate_tool.name = existing_tool_name + mock_session_new.list_tools.return_value = mock.AsyncMock( + tools=[duplicate_tool] + ) + # Keep other lists empty for simplicity + mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Execution and Assertion --- + with pytest.raises(McpError) as excinfo: + with mock.patch.object( + group, + "_establish_session", + return_value=(mock_server_info_new, mock_session_new), + ): + await group.connect_to_server(StdioServerParameters(command="test")) + + # Assert details about the raised error + assert excinfo.value.error.code == types.INVALID_PARAMS + assert existing_tool_name in excinfo.value.error.message + assert "already exist " in excinfo.value.error.message + + # Verify the duplicate tool was *not* added again (state should be unchanged) + assert len(group._tools) == 1 # Should still only have the original + assert ( + group._tools[existing_tool_name] is not duplicate_tool + ) # Ensure it's the original mock + + # No patching needed here + async def test_disconnect_non_existent_server(self): + """Test disconnecting a server that isn't connected.""" + session = mock.Mock(spec=mcp.ClientSession) + group = ClientSessionGroup() + with pytest.raises(McpError): + await group.disconnect_from_server(session) + + @pytest.mark.parametrize( + "server_params_instance, client_type_name, patch_target_for_client_func", + [ + ( + StdioServerParameters(command="test_stdio_cmd"), + "stdio", + "mcp.client.session_group.mcp.stdio_client", + ), + ( + SseServerParameters(url="http://test.com/sse", timeout=10), + "sse", + "mcp.client.session_group.sse_client", + ), # url, headers, timeout, sse_read_timeout + ( + StreamableHttpParameters( + url="http://test.com/stream", terminate_on_close=False + ), + "streamablehttp", + "mcp.client.session_group.streamablehttp_client", + ), # url, headers, timeout, sse_read_timeout, terminate_on_close + ], + ) + async def test_establish_session_parameterized( + self, + server_params_instance, + client_type_name, # Just for clarity or conditional logic if needed + patch_target_for_client_func, + ): + with mock.patch( + "mcp.client.session_group.mcp.ClientSession" + ) as mock_ClientSession_class: + with mock.patch(patch_target_for_client_func) as mock_specific_client_func: + mock_client_cm_instance = mock.AsyncMock( + name=f"{client_type_name}ClientCM" + ) + mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") + mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") + + # streamablehttp_client's __aenter__ returns three values + if client_type_name == "streamablehttp": + mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra") + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_extra_stream_val, + ) + else: + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) + mock_specific_client_func.return_value = mock_client_cm_instance + + # --- Mock mcp.ClientSession (class) --- + # mock_ClientSession_class is already provided by the outer patch + mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") + mock_ClientSession_class.return_value = mock_raw_session_cm + + mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") + mock_raw_session_cm.__aenter__.return_value = mock_entered_session + mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + # Mock session.initialize() + mock_initialize_result = mock.AsyncMock(name="InitializeResult") + mock_initialize_result.serverInfo = types.Implementation( + name="foo", version="1" + ) + mock_entered_session.initialize.return_value = mock_initialize_result + + # --- Test Execution --- + group = ClientSessionGroup() + returned_server_info = None + returned_session = None + + async with contextlib.AsyncExitStack() as stack: + group._exit_stack = stack + ( + returned_server_info, + returned_session, + ) = await group._establish_session(server_params_instance) + + # --- Assertions --- + # 1. Assert the correct specific client function was called + if client_type_name == "stdio": + mock_specific_client_func.assert_called_once_with( + server_params_instance + ) + elif client_type_name == "sse": + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + ) + elif client_type_name == "streamablehttp": + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + terminate_on_close=server_params_instance.terminate_on_close, + ) + + mock_client_cm_instance.__aenter__.assert_awaited_once() + + # 2. Assert ClientSession was called correctly + mock_ClientSession_class.assert_called_once_with( + mock_read_stream, mock_write_stream + ) + mock_raw_session_cm.__aenter__.assert_awaited_once() + mock_entered_session.initialize.assert_awaited_once() + + # 3. Assert returned values + assert returned_server_info is mock_initialize_result.serverInfo + assert returned_session is mock_entered_session