-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Open
Labels
client — Related to the FastMCP client SDK or client-side functionality. enhancement — Improvement to existing functionality. For issues and smaller PR improvements.
Description
Enhancement
Off the top of my head, it will probably look vaguely like this:
"""Google GenAI sampling handler with tool support for FastMCP 3.0."""
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
from google.genai import Client as GoogleGenaiClient
from google.genai.types import (
Candidate,
Content,
FunctionCall,
FunctionCallingConfig,
FunctionCallingConfigMode,
FunctionDeclaration,
FunctionResponse,
GenerateContentConfig,
GenerateContentResponse,
ModelContent,
Part,
ThinkingConfig,
ToolConfig,
UserContent,
)
from google.genai.types import Tool as GoogleTool
from mcp import ClientSession, ServerSession
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.types import (
AudioContent,
CreateMessageResult,
CreateMessageResultWithTools,
ImageContent,
ModelPreferences,
SamplingMessage,
SamplingMessageContentBlock,
StopReason,
TextContent,
ToolChoice,
ToolResultContent,
ToolUseContent,
)
from mcp.types import CreateMessageRequestParams as SamplingParams
from mcp.types import Tool as MCPTool
class GoogleGenaiSamplingHandler:
    """Sampling handler that uses the Google GenAI API with tool support.

    MCP sampling messages, tool definitions and tool-choice settings are
    translated to their Google GenAI equivalents, the request is sent through
    ``client.aio.models.generate_content``, and the response is translated back
    into an MCP ``CreateMessageResult`` (no tools) or
    ``CreateMessageResultWithTools`` (tools were offered).
    """

    def __init__(
        self,
        default_model: str,
        client: GoogleGenaiClient | None = None,
        *,
        thinking_budget: int | None = 200,
    ):
        """Create the handler.

        Args:
            default_model: Model name used when the request carries no named
                model hint.
            client: Pre-configured Google GenAI client. When omitted a
                ``GoogleGenaiClient()`` is constructed with no arguments
                (presumably picking up credentials from the environment — verify
                against the google-genai docs).
            thinking_budget: Token budget sent via ``ThinkingConfig``. ``None``
                omits the thinking config entirely, e.g. for models that do not
                support thinking.
        """
        self.client: GoogleGenaiClient = client or GoogleGenaiClient()
        self.default_model: str = default_model
        self.thinking_budget: int | None = thinking_budget

    async def __call__(
        self,
        messages: list[SamplingMessage],
        params: SamplingParams,
        context: RequestContext[ServerSession, LifespanContextT]
        | RequestContext[ClientSession, LifespanContextT],
    ) -> CreateMessageResult | CreateMessageResultWithTools:
        """Answer one MCP sampling request with a Google GenAI completion.

        Args:
            messages: Conversation history to send to the model.
            params: Sampling parameters (system prompt, temperature, token
                limit, stop sequences, model preferences, tools, tool choice).
            context: Request context; not used by this handler but part of the
                sampling-handler call signature.

        Returns:
            ``CreateMessageResultWithTools`` when ``params.tools`` is non-empty,
            otherwise ``CreateMessageResult``. The ``model`` field reports the
            model actually requested (hint or default).

        Raises:
            ValueError: Propagated from the conversion helpers when the
                messages or the response cannot be translated.
        """
        contents: list[Content] = _convert_messages_to_google_genai_content(messages)

        # Convert MCP tools to Google GenAI format only when tools were offered.
        google_tools: list[GoogleTool] | None = None
        tool_config: ToolConfig | None = None
        if params.tools:
            google_tools = [_convert_tool_to_google_genai(tool) for tool in params.tools]
            tool_config = _convert_tool_choice_to_google_genai(params.toolChoice)

        # Resolve the model once so the request and the reported result agree.
        model: str = self._get_model(model_preferences=params.modelPreferences)

        thinking_config: ThinkingConfig | None = (
            ThinkingConfig(thinking_budget=self.thinking_budget)
            if self.thinking_budget is not None
            else None
        )

        response: GenerateContentResponse = await self.client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=GenerateContentConfig(
                system_instruction=params.systemPrompt,
                temperature=params.temperature,
                max_output_tokens=params.maxTokens,
                stop_sequences=params.stopSequences,
                thinking_config=thinking_config,
                tools=google_tools,  # type: ignore[arg-type]
                tool_config=tool_config,
            ),
        )

        # Return appropriate result type based on whether tools were provided.
        if params.tools:
            return _response_to_result_with_tools(response, model)
        return _response_to_create_message_result(response, model)

    def _get_model(self, model_preferences: ModelPreferences | None) -> str:
        """Return the first named model hint, falling back to ``default_model``.

        NOTE(review): hint names are passed to the Google API unfiltered, so a
        server hinting a non-Gemini model name would cause a request error.
        """
        if model_preferences and model_preferences.hints:
            for hint in model_preferences.hints:
                if hint.name:
                    return hint.name
        return self.default_model
def _convert_tool_to_google_genai(tool: MCPTool) -> GoogleTool:
    """Convert an MCP Tool to a Google GenAI ``Tool`` with one function declaration.

    Each top-level property of the tool's JSON ``inputSchema`` is converted with
    ``_convert_json_schema_to_google_schema``. Tools without any properties get
    no ``parameters`` at all, because the Gemini API rejects an OBJECT schema
    whose ``properties`` map is empty.

    Args:
        tool: The MCP tool definition.

    Returns:
        A ``GoogleTool`` wrapping a single ``FunctionDeclaration``.
    """
    input_schema: dict[str, Any] = tool.inputSchema
    # ``or`` also guards against keys that are present but explicitly None.
    properties: dict[str, Any] = input_schema.get("properties") or {}
    required: list[str] = input_schema.get("required") or []

    # Build parameters schema with Google's type format.
    google_properties: dict[str, Any] = {
        prop_name: _convert_json_schema_to_google_schema(dict(prop_schema))
        for prop_name, prop_schema in properties.items()
    }

    parameters: dict[str, Any] | None = None
    if google_properties:
        parameters = {
            "type": "OBJECT",
            "properties": google_properties,
            "required": required,
        }

    return GoogleTool(
        function_declarations=[
            FunctionDeclaration(
                name=tool.name,
                description=tool.description or "",
                parameters=parameters,  # type: ignore[arg-type]
            )
        ]
    )
def _convert_json_schema_to_google_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Convert JSON Schema to Google GenAI Schema format.
Handles:
- Basic types (string, integer, number, boolean, array, object)
- Nullable types via anyOf with null type
- Nested objects and arrays
"""
result: dict[str, Any] = {}
# Handle anyOf for nullable types (e.g., anyOf: [{type: string}, {type: null}])
if "anyOf" in schema:
any_of_types = schema["anyOf"]
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
has_null = len(non_null_types) < len(any_of_types)
if non_null_types:
# Recursively convert the non-null type
non_null_schema = non_null_types[0]
result = _convert_json_schema_to_google_schema(non_null_schema)
if has_null:
result["nullable"] = True
# Preserve description from parent schema
if "description" in schema:
result["description"] = schema["description"]
return result
schema_type: str | None = schema.get("type")
if schema_type:
type_map: dict[str, str] = {
"string": "STRING",
"integer": "INTEGER",
"number": "NUMBER",
"boolean": "BOOLEAN",
"array": "ARRAY",
"object": "OBJECT",
}
result["type"] = type_map.get(schema_type, "STRING")
if "description" in schema:
result["description"] = schema["description"]
if "enum" in schema:
result["enum"] = schema["enum"]
if "items" in schema:
result["items"] = _convert_json_schema_to_google_schema(dict(schema["items"]))
if "properties" in schema:
result["properties"] = {
str(k): _convert_json_schema_to_google_schema(dict(v)) for k, v in dict(schema["properties"]).items()
}
if "required" in schema:
result["required"] = schema["required"]
return result
def _convert_tool_choice_to_google_genai(tool_choice: ToolChoice | None) -> ToolConfig:
"""Convert MCP ToolChoice to Google GenAI ToolConfig."""
if tool_choice is None:
return ToolConfig(function_calling_config=FunctionCallingConfig(mode=FunctionCallingConfigMode.AUTO))
if tool_choice.mode == "required":
return ToolConfig(function_calling_config=FunctionCallingConfig(mode=FunctionCallingConfigMode.ANY))
if tool_choice.mode == "none":
return ToolConfig(function_calling_config=FunctionCallingConfig(mode=FunctionCallingConfigMode.NONE))
# Default to AUTO for "auto" or any other value
return ToolConfig(function_calling_config=FunctionCallingConfig(mode=FunctionCallingConfigMode.AUTO))
def _sampling_content_to_google_genai_part(
    content: TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent,
) -> Part:
    """Convert one MCP content block to a Google GenAI Part.

    - Text becomes a text part.
    - Image and audio blocks (base64 ``data`` plus ``mimeType``) become
      inline-bytes parts.
    - Tool-use blocks become ``function_call`` parts.
    - Tool-result blocks become ``function_response`` parts. Only their text
      items are forwarded. Results flagged ``isError`` are sent under the
      ``"error"`` key, otherwise under ``"result"``.

    Raises:
        ValueError: For any other content type.
    """
    if isinstance(content, TextContent):
        return Part(text=content.text)

    if isinstance(content, (ImageContent, AudioContent)):
        import base64

        # MCP carries binary media as base64 text; Part.from_bytes wants raw bytes.
        return Part.from_bytes(data=base64.b64decode(content.data), mime_type=content.mimeType)

    if isinstance(content, ToolUseContent):
        return Part(
            function_call=FunctionCall(
                name=content.name,
                args=content.input,
            )
        )

    if isinstance(content, ToolResultContent):
        # Extract text from tool result content; non-text items are dropped.
        result_text = "".join(
            item.text for item in (content.content or []) if isinstance(item, TextContent)
        )
        # Our IDs are formatted as "{function_name}_{uuid8}", so strip the last
        # "_"-separated segment. An ID without "_" is used unchanged as the name.
        function_name = content.toolUseId.rsplit("_", 1)[0]
        response_key = "error" if getattr(content, "isError", False) else "result"
        return Part(
            function_response=FunctionResponse(
                name=function_name,
                response={response_key: result_text},
            )
        )

    msg = f"Unsupported content type: {type(content)}"
    raise ValueError(msg)
def _convert_messages_to_google_genai_content(
    messages: Sequence[SamplingMessage],
) -> list[Content]:
    """Convert MCP sampling messages to Google GenAI Content objects.

    Each message becomes one Content whose parts come from
    ``_sampling_content_to_google_genai_part``. A message's content may be a
    single block or a list of blocks (e.g. tool calls plus results). Role
    ``"user"`` maps to ``UserContent`` and ``"assistant"`` maps to
    ``ModelContent``.

    Raises:
        ValueError: For any other role, whether the content is a single block
            or a list.
    """
    google_messages: list[Content] = []
    for message in messages:
        # Normalise single-block content to a list so both shapes share one path.
        blocks = message.content if isinstance(message.content, list) else [message.content]
        parts: list[Part] = [
            _sampling_content_to_google_genai_part(block)  # type: ignore[arg-type]
            for block in blocks
        ]
        if message.role == "user":
            google_messages.append(UserContent(parts=parts))
        elif message.role == "assistant":
            google_messages.append(ModelContent(parts=parts))
        else:
            msg = f"Invalid message role: {message.role}"
            raise ValueError(msg)
    return google_messages
def _get_candidate_from_response(response: GenerateContentResponse) -> Candidate:
    """Return the first candidate of *response*.

    Raises:
        ValueError: If the response has no candidates or the first one is falsy.
    """
    candidates = response.candidates
    first = candidates[0] if candidates else None
    if not first:
        raise ValueError("No candidate in response from completion.")
    return first
def _response_to_create_message_result(
    response: GenerateContentResponse,
    model: str,
) -> CreateMessageResult:
    """Convert a Google GenAI response to a CreateMessageResult (no tools).

    The reply text is ``response.text``. ``stopReason`` is ``"maxTokens"`` when
    the first candidate finished with ``MAX_TOKENS``, otherwise ``"endTurn"``.

    Args:
        response: The ``generate_content`` response.
        model: Model name to report in the result.

    Raises:
        ValueError: If the response contains no text; the message includes the
            candidate's finish reason.
    """
    if not (text := response.text):
        candidate = _get_candidate_from_response(response)
        msg = f"No content in response: {candidate.finish_reason}"
        raise ValueError(msg)

    finish_reason = response.candidates[0].finish_reason if response.candidates else None
    # Gemini does not distinguish stop-sequence hits from natural stops, so
    # only truncation gets a distinct stop reason.
    stop_reason: StopReason = "maxTokens" if finish_reason == "MAX_TOKENS" else "endTurn"

    return CreateMessageResult(
        content=TextContent(type="text", text=text),
        role="assistant",
        model=model,
        stopReason=stop_reason,
    )
def _response_to_result_with_tools(
    response: GenerateContentResponse,
    model: str,
) -> CreateMessageResultWithTools:
    """Convert a Google GenAI response to a CreateMessageResultWithTools.

    Text parts become ``TextContent`` and function-call parts become
    ``ToolUseContent``. Each ``ToolUseContent`` gets a generated id
    ``"{name}_{8 hex chars}"``; the tool-result conversion strips that suffix to
    recover the name. Parts flagged as thoughts are skipped.

    The stop reason is ``"toolUse"`` if any part is a function call,
    ``"maxTokens"`` for a ``MAX_TOKENS`` finish, and ``"endTurn"`` otherwise.

    Args:
        response: The ``generate_content`` response.
        model: Model name to report in the result.

    Raises:
        ValueError: If there is no candidate or no convertible content.
    """
    candidate = _get_candidate_from_response(response)
    parts = candidate.content.parts if candidate.content and candidate.content.parts else []

    content: list[SamplingMessageContentBlock] = []
    has_function_calls = False
    for part in parts:
        if part.function_call:
            has_function_calls = True
        # Thought summaries are internal reasoning, not part of the reply.
        if getattr(part, "thought", None):
            continue
        if part.text:
            content.append(TextContent(type="text", text=part.text))
        elif part.function_call:
            fc = part.function_call
            fc_name: str = fc.name or "unknown"
            content.append(
                ToolUseContent(
                    type="tool_use",
                    id=f"{fc_name}_{uuid4().hex[:8]}",  # Generate unique ID
                    name=fc_name,
                    input=dict(fc.args) if fc.args else {},
                )
            )

    stop_reason: StopReason
    if has_function_calls:
        stop_reason = "toolUse"
    elif candidate.finish_reason == "MAX_TOKENS":
        stop_reason = "maxTokens"
    else:
        stop_reason = "endTurn"

    if not content:
        raise ValueError("No content in response from completion")

    return CreateMessageResultWithTools(
        content=content,
        role="assistant",
        model=model,
        stopReason=stop_reason,
    )
But who knows.
Metadata
Metadata
Assignees
Labels
client — Related to the FastMCP client SDK or client-side functionality. enhancement — Improvement to existing functionality. For issues and smaller PR improvements.