1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
|
import asyncio
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters, stdio_client
from mcp.client.sse import sse_client
from openai import OpenAI
import json
# Example usage (kept for reference; not executed):
# async def run_client(query: str, config: str):
#     client = MCPClient()
#     try:
#         await client.connect_to_server(config)
#         response = await client.process_query(query)
#         return response  # return the processed result
#     finally:
#         await client.cleanup()
class MCPClient:
    """Chat client that lets an OpenAI-compatible model call tools exposed by MCP servers.

    Several MCP servers can be connected (over stdio or SSE). Every tool they
    expose is recorded in ``tool_map`` so that a tool call returned by the
    model is routed to the session that provides it.
    """

    def __init__(
        self,
        api_key: str = "xxxxxx",
        base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
        model: str = "xxxxxx",
    ):
        """Create the client.

        Args:
            api_key: API key for the chat-completions endpoint. The default is
                the original placeholder; pass a real key (e.g. loaded from the
                environment or a secret store) rather than editing the source.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Model name sent with every completion request.
        """
        self.sessions = []  # every connected MCP ClientSession, in connection order
        self.tool_map = {}  # tool name -> ClientSession that provides it
        # Owns every transport and session opened by this client; closed in cleanup().
        self.exit_stack = AsyncExitStack()
        self.openai_api_key = api_key
        self.base_url = base_url
        self.model = model
        self.client = OpenAI(api_key=self.openai_api_key, base_url=self.base_url)

    async def connect_to_server(self, config: str):
        """Connect to one MCP server described by a JSON config string and register its tools.

        The parsed config is searched depth-first for the first object with a
        non-null "command" key. If one is found, the server is launched over
        stdio using that same object's "args" and "env". Otherwise, the first
        object with a non-null "url" key is used to connect over SSE.

        Returns:
            The list of tools reported by the server.

        Raises:
            ValueError: if the config contains neither a "command" nor a "url".
        """
        data = json.loads(config)
        entry = self._find_dict_with_key(data, "command")
        if entry is not None:
            # Read args/env from the same object as "command" so that the
            # settings of different servers in one config are never mixed.
            tools, session = await self.stdio_connect_to_server(
                entry["command"], entry.get("args") or [], entry.get("env")
            )
        else:
            entry = self._find_dict_with_key(data, "url")
            if entry is None:
                raise ValueError(
                    "MCP server config must contain a 'command' (stdio) or 'url' (SSE) entry"
                )
            tools, session = await self.sse_connect_to_server(entry["url"])
        self.sessions.append(session)
        # Map each tool name to its session. If two servers expose the same
        # name, the most recently connected server wins.
        for tool in tools:
            self.tool_map[tool.name] = session
        print("\nConnected to server with tools:", [tool.name for tool in tools])
        return tools

    async def process_query(self, query: str, max_tool_rounds: int = 5) -> str:
        """Answer *query* with the model, executing any MCP tools it requests.

        Args:
            query: The user's question.
            max_tool_rounds: Maximum number of tool-call rounds. Once it is
                reached, the model's latest text is returned as-is.

        Returns:
            The model's final text answer, stripped ("" if it produced no text).

        Raises:
            Exception: if the model calls a tool that no connected server provides.
        """
        messages = [
            {"role": "system", "content": "你是一个智能助手,帮助用户回答问题。"},
            {"role": "user", "content": query},
        ]
        all_tools = await self._collect_tool_specs()

        response = await self._create_completion(messages, all_tools)
        choice = response.choices[0]
        rounds = 0
        while (
            choice.finish_reason == "tool_calls"
            and choice.message.tool_calls
            and rounds < max_tool_rounds
        ):
            rounds += 1
            # The assistant turn carrying the tool calls must precede their results.
            messages.append(choice.message.model_dump())
            # Answer EVERY tool call. The chat API expects a "tool" message for
            # each tool_call id in the preceding assistant message.
            for tool_call in choice.message.tool_calls:
                messages.append(await self._run_tool_call(tool_call))
            # Feed the tool results back so the model can produce its answer
            # (or request further tools).
            response = await self._create_completion(messages, all_tools)
            choice = response.choices[0]
        return (choice.message.content or "").strip()

    async def _collect_tool_specs(self) -> list:
        """Build OpenAI "function" tool specs for every tool in ``tool_map``.

        Each session's tool list is fetched at most once per call. Only tools
        whose name maps to that session are included, in ``tool_map`` order.
        """
        listed = {}  # id(session) -> tools reported by that session
        specs = []
        for tool_name, session in self.tool_map.items():
            key = id(session)
            if key not in listed:
                listed[key] = (await session.list_tools()).tools
            for tool in listed[key]:
                if tool.name == tool_name:
                    specs.append({
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.inputSchema,
                        },
                    })
        return specs

    async def _create_completion(self, messages, tools):
        """Run one chat-completion request without blocking the event loop.

        The "tools" argument is omitted when *tools* is empty, because an empty
        tools array is rejected by the OpenAI API.
        """
        kwargs = {"model": self.model, "messages": messages}
        if tools:
            kwargs["tools"] = tools
        # The OpenAI client here is synchronous; run it in a worker thread.
        return await asyncio.to_thread(self.client.chat.completions.create, **kwargs)

    async def _run_tool_call(self, tool_call) -> dict:
        """Execute one model tool call on its MCP session and return the "tool" message."""
        function_name = tool_call.function.name
        # Some providers send "" instead of "{}" for tools without arguments.
        function_args = json.loads(tool_call.function.arguments or "{}")
        session = self.tool_map.get(function_name)
        if session is None:
            raise Exception(f"Tool {function_name} not found in any session")
        print(f"\n\n[Calling tool {function_name} with args {function_args}]\n\n")
        result = await session.call_tool(function_name, function_args)
        return {
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": function_name,
            "content": self._result_text(result),
        }

    @staticmethod
    def _result_text(result) -> str:
        """Join the text parts of an MCP tool result; parts without text are skipped."""
        return "\n".join(
            part.text
            for part in result.content
            if getattr(part, "text", None) is not None
        )

    async def cleanup(self):
        """Close every transport and session opened by this client and forget them."""
        try:
            await self.exit_stack.aclose()
        finally:
            # Drop references to sessions that are no longer usable.
            self.sessions.clear()
            self.tool_map.clear()

    async def stdio_connect_to_server(self, command, args, env):
        """Launch an MCP server as a subprocess and open a session over its stdio.

        Args:
            command: Executable to run.
            args: Command-line arguments; None is treated as no arguments.
            env: Environment for the subprocess, or None to leave it unset.

        Returns:
            ``(tools, session)``: the server's tool list and the initialized session.
        """
        params = {"command": command, "args": args if args is not None else []}
        if env is not None:
            params["env"] = env
        server_params = StdioServerParameters(**params)
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        return await self._open_session(read_stream, write_stream)

    async def sse_connect_to_server(self, url):
        """Connect to an MCP server over SSE at *url* and open a session.

        Returns:
            ``(tools, session)``: the server's tool list and the initialized session.
        """
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            sse_client(url)
        )
        return await self._open_session(read_stream, write_stream)

    async def _open_session(self, read_stream, write_stream):
        """Enter a ClientSession into the exit stack, initialize it, and list its tools."""
        session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await session.initialize()
        response = await session.list_tools()
        return response.tools, session

    def find_key_in_dict(self, d, target_key):
        """Recursively search nested dicts/lists for *target_key* and return its value.

        The search is depth-first in iteration order. A matching key returns
        its value immediately, even if that value is None. A None result from a
        nested search counts as "not found", and the search continues.
        Returns None when the key is absent.
        """
        if isinstance(d, dict):
            for key, value in d.items():
                if key == target_key:
                    return value
                if isinstance(value, (dict, list)):
                    result = self.find_key_in_dict(value, target_key)
                    if result is not None:
                        return result
        elif isinstance(d, list):
            for item in d:
                result = self.find_key_in_dict(item, target_key)
                if result is not None:
                    return result
        return None

    def _find_dict_with_key(self, d, target_key) -> Optional[dict]:
        """Return the first nested dict whose *target_key* value is not None, else None.

        Each dict checks its own keys before descending into its values
        (depth-first, in iteration order).
        """
        if isinstance(d, dict):
            if d.get(target_key) is not None:
                return d
            children = d.values()
        elif isinstance(d, list):
            children = d
        else:
            return None
        for child in children:
            found = self._find_dict_with_key(child, target_key)
            if found is not None:
                return found
        return None
|