1. 项目概述:为什么“模型无关”不是一句空话,而是一道必须跨过的工程门槛
你有没有遇到过这样的场景:团队刚用 GPT-4 Turbo 跑通了一个客服对话路由 Agent,客户突然要求“必须切换成国产大模型,且不能改一行业务逻辑”;或者你在做金融合规审核 Agent,昨天还在调用 Qwen2.5-72B 的函数调用能力,今天监管侧要求所有推理必须走本地部署的 DeepSeek-V3,接口协议、流式响应格式、token 计费方式、错误码体系全都不一样——这时候,你写的那套llm.invoke()调用链,是不是瞬间变成了一堆需要逐行重写的“技术债”?这正是 LangChain 社区在 Part 29 中正式提出Model Agnostic Pattern(模型无关模式)的真实战场。它不是教你怎么调 API,而是教你如何把“调哪个模型”这件事,从代码里彻底抽离出来,变成一个可配置、可热替换、可灰度发布、甚至可按用户 ID 动态路由的运行时决策。而LLM API Gateway(大模型 API 网关),就是支撑这个模式落地的基础设施底座——它不生产模型,但决定谁来响应、怎么响应、响应失败后怎么兜底。我带过三个 AI 应用交付项目,其中两个卡在模型切换环节超过三周,不是因为模型能力不行,而是因为整个调用链路像一捆缠死的电线,剪哪根都断服务。这篇内容,就是我把这捆电线一根根剥开、标上色、装上快插接头的过程。它适合所有正在用 LangChain 构建生产级 AI Agent 的人,尤其是那些已经踩过“硬编码模型名”“手动转换消息格式”“重写 retry 逻辑”这些坑的工程师和架构师。如果你还在ChatOpenAI(model="gpt-4-turbo")里写死模型名,那你不是在写应用,是在给未来埋雷。
2. 核心设计思路拆解:从“硬编码调用”到“协议抽象层”的四步跃迁
2.1 为什么传统 LangChain LLM 封装方式注定不可扩展
LangChain 早期的ChatOpenAI、ChatAnthropic、ChatQwen这类封装器,表面看是“开箱即用”,实则暗藏三重耦合陷阱:
协议耦合:
ChatOpenAI强依赖 OpenAI 的/v1/chat/completions接口规范,包括messages字段必须是{"role": "user", "content": "xxx"}结构,response_format必须是 JSON Schema,tool_choice必须是"auto"或{"type": "function", "function": {"name": "xxx"}}。而 Qwen 的tools字段叫functions,DeepSeek 的tool_choice叫enable_thinking,千问的流式响应是delta.content,而 GLM 的是choices[0].delta.text。一旦换模型,光是消息序列的to_messages()方法就得重写。行为耦合:
ChatOpenAI的stream=True返回的是AsyncIterator[ChatGenerationChunk],而本地 Ollama 模型返回的是Iterator[dict],类型系统完全断裂。更麻烦的是 retry 逻辑——OpenAI 的429错误要指数退避,而某国产模型的429是永久性配额超限,retry 百次也没用;某云厂商的503表示模型实例未就绪,需要轮询健康检查端点,而不是简单 sleep。这些差异被封装器粗暴地“统一”成BaseLLM.generate(),结果就是线上报错时,你得翻三份文档才能定位问题。生命周期耦合:
ChatOpenAI实例初始化时就绑定了api_key、base_url、model_name,这意味着你无法在同一个 Agent 流程中,对不同子任务动态切换模型——比如“法律条款解析”用高精度 72B 模型,“用户情绪判断”用低延迟 7B 模型,“多语言翻译”用专精小模型。你只能写三套几乎一样的 Chain,维护成本翻三倍。
我去年在一个跨境电商业务中就栽在这上面:最初用ChatOpenAI(model="gpt-3.5-turbo")做商品描述生成,QPS 上去后成本飙升,想切到ChatQwen(model="qwen2.5-7b"),结果发现qwen2.5-7b不支持response_format={"type": "json_object"},JSON 输出全靠 prompt engineering + post-process 正则提取,准确率从 99.2% 掉到 83.7%。最后不是换模型,而是给 Qwen 单独写了个QwenJSONOutputParser,又额外加了 200 行校验逻辑。这就是“协议耦合”带来的典型熵增。
2.2 Model Agnostic Pattern 的四层抽象设计
LangGraph Part 29 提出的 Model Agnostic Pattern,本质是构建一个四层抽象漏斗,把模型差异全部拦在最外层:
第 1 层:统一输入契约(Unified Input Contract)
定义一个与任何模型无关的AIAgentInput数据类,只包含业务语义字段:user_query: str、context: Dict[str, Any]、required_tools: List[str]、output_format: Literal["text", "json", "structured"]。所有上游 Agent 节点(如 Router、Retriever)只跟这个契约打交道,绝不碰messages、tools这些协议层字段。第 2 层:协议适配器(Protocol Adapter)
每个模型供应商对应一个 Adapter 类,例如OpenAIAdapter、QwenAdapter、DeepSeekAdapter。它的唯一职责,是把AIAgentInput翻译成该模型能理解的原始请求体,并把原始响应翻译回统一的AIAgentOutput(含text: str、structured_data: Optional[Dict]、tool_calls: List[ToolCall])。Adapter 内部封装所有协议细节:消息格式转换、tool call 解析、流式 chunk 合并、错误码映射(如把 Qwen 的50001映射为标准LLMRateLimitError)。第 3 层:模型路由网关(Model Routing Gateway)
这就是 LLM API Gateway 的核心。它接收标准化的AIAgentInput,根据预设策略(如model_policy: "by_user_tier"、"by_latency_sla"、"by_cost_budget")查询路由规则引擎,决定调用哪个 Adapter。规则引擎本身是可插拔的——可以是内存中的Dict配置,也可以是连接 Redis 的实时策略中心,甚至可以是调用另一个轻量级 LLM 做动态决策(比如用 7B 模型分析 query 复杂度,再决定是否升配到 72B)。第 4 层:运行时执行器(Runtime Executor)
一个薄薄的ModelExecutor类,只做三件事:1)根据路由结果实例化对应 Adapter;2)调用 Adapter 的invoke()方法;3)捕获所有底层异常,统一包装为AIAgentError并附带adapter_name、original_error等上下文。它不关心模型怎么工作,只确保“输入进来,输出出去,出错了有迹可循”。
这个设计最妙的地方在于,当你新增一个模型(比如刚发布的 Kimi 2.0),你只需要写一个新的KimiAdapter,实现adapt_input()和adapt_output()两个方法,然后在网关配置里加一条路由规则,整个系统无需重启、无需改任何业务代码,就能接入新模型。我在上个月就用这套模式,在 2 小时内把一个已上线的合同审查 Agent,从gpt-4o切换到了kimi-pro,全程零 downtime,用户无感知。
2.3 为什么必须用 LangGraph 而非纯 LangChain 实现
很多人会问:LangChain 本身就有BaseLLM抽象,为啥还要 LangGraph?关键在于状态持久化和节点可组合性。LangChain 的 Chain 是线性的、一次性的,LLMChain执行完就销毁。而 LangGraph 的StateGraph是有状态的图结构,每个节点(Node)可以读写共享状态State。Model Agnostic Pattern 的核心价值,恰恰体现在“状态驱动的动态路由”上:
- 你可以让 Router 节点根据
State["query_complexity"]字段,决定调用high_precision_adapter还是low_latency_adapter; - 你可以让 Fallback 节点监听
State["last_adapter_failure"],当OpenAIAdapter连续失败 3 次,自动触发QwenAdapter降级; - 你可以让 Metrics 节点在每次
ModelExecutor执行前后,记录adapter_name、input_tokens、output_tokens、latency_ms到 Prometheus,形成模型级的可观测性大盘。
LangChain 的Runnable也能做类似事,但它缺乏原生的状态管理能力。你得自己用functools.partial绑定上下文,或者用threading.local()存状态,一到异步环境就崩。而 LangGraph 的State是深度集成的,State.update()是原子操作,State.get()是线程安全的,这才是生产环境需要的确定性。我见过太多团队用 LangChain + 自研状态管理,结果在高并发下出现状态污染,查 bug 查了三天,最后发现是global变量被多个协程同时修改。LangGraph 把这个坑直接填平了。
3. 核心组件实现详解:从协议适配器到网关路由的完整代码实录
3.1 统一输入/输出契约定义:用 Pydantic V2 做强类型保障
我们先定义最顶层的契约。注意,这里不用dataclass,而用 Pydantic V2 的BaseModel,因为它自带验证、序列化、文档生成能力,且与 FastAPI、LangServe 天然兼容:
from typing import List, Dict, Any, Optional, Literal, Union from pydantic import BaseModel, Field, field_validator import re class ToolCall(BaseModel): """统一工具调用表示,屏蔽各模型 tool_call 字段差异""" name: str = Field(..., description="工具名称") arguments: Dict[str, Any] = Field(default_factory=dict, description="工具参数字典") id: Optional[str] = Field(None, description="调用ID,用于响应关联") class AIAgentOutput(BaseModel): """模型输出的统一表示""" text: str = Field("", description="纯文本输出") structured_data: Optional[Dict[str, Any]] = Field( None, description="结构化数据,如JSON解析结果" ) tool_calls: List[ToolCall] = Field(default_factory=list, description="工具调用列表") metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据,如token用量、延迟") @field_validator("structured_data") @classmethod def validate_structured_data(cls, v): if v is not None and not isinstance(v, dict): raise ValueError("structured_data must be a dict") return v class AIAgentInput(BaseModel): """统一输入契约,业务语义优先""" user_query: str = Field(..., description="用户原始查询") context: Dict[str, Any] = Field(default_factory=dict, description="上下文信息") required_tools: List[str] = Field(default_factory=list, description="必需的工具列表") output_format: Literal["text", "json", "structured"] = Field( "text", description="期望输出格式" ) max_tokens: int = Field(2048, ge=1, le=32768, description="最大生成token数") temperature: float = Field(0.3, ge=0.0, le=2.0, description="采样温度") # 注意:这里没有 model_name、api_key、base_url!它们属于网关配置 @field_validator("user_query") @classmethod def validate_query_length(cls, v): if len(v.strip()) == 0: raise ValueError("user_query cannot be empty or whitespace only") if len(v) > 128000: # 限制单次输入长度,防OOM raise ValueError("user_query too long, max 128000 chars") return v这个契约的设计哲学是:只暴露业务需要的字段,隐藏所有技术细节。user_query是用户说了什么,required_tools是业务逻辑决定要调哪些工具,output_format是业务目标(要 JSON 还是纯文本)。至于“用哪个模型”、“走哪个 endpoint”、“key 怎么鉴权”,统统不在这里出现。我坚持这个原则,是因为在三个项目中,业务方提需求时从来不会说“我要调 Qwen 的 /v1/chat/completions”,他们只会说“这个合同条款要解析成 JSON 格式”。契约必须反映业务语言,而不是技术语言。
3.2 协议适配器实现:以 OpenAI 和 Qwen 为例的双向翻译
适配器的核心是adapt_input()和adapt_output()两个方法。我们以OpenAIAdapter为例,展示如何把AIAgentInput翻译成 OpenAI 的标准请求体:
import json from openai import AsyncOpenAI from openai.types.chat import ChatCompletionMessageParam from typing import List, Dict, Any, Optional, cast class OpenAIAdapter: def __init__(self, api_key: str, base_url: str = "https://api.openai.com/v1"): self.client = AsyncOpenAI(api_key=api_key, base_url=base_url) self.model_name = "gpt-4o-mini" # 默认模型,可覆盖 def adapt_input(self, input_data: AIAgentInput) -> Dict[str, Any]: """将 AIAgentInput 翻译为 OpenAI API 请求体""" # 构建 messages:系统提示 + 用户查询 + 上下文 messages: List[ChatCompletionMessageParam] = [ {"role": "system", "content": "You are a helpful AI assistant."} ] # 添加上下文(如果存在) if input_data.context: context_str = "\n".join([ f"{k}: {v}" for k, v in input_data.context.items() ]) messages.append({"role": "system", "content": f"Context:\n{context_str}"}) # 添加用户查询 messages.append({"role": "user", "content": input_data.user_query}) # 构建 tools(如果需要) tools = [] if input_data.required_tools: # 这里假设你有一个 tools registry,根据 name 查找 tool spec from my_tools_registry import TOOLS_REGISTRY for tool_name in input_data.required_tools: if tool_name in TOOLS_REGISTRY: tools.append(TOOLS_REGISTRY[tool_name]) # 构建请求参数 request_body = { "model": self.model_name, "messages": messages, "max_tokens": input_data.max_tokens, "temperature": input_data.temperature, } # 根据 output_format 添加 response_format 或 tool_choice if input_data.output_format == "json": request_body["response_format"] = {"type": "json_object"} elif input_data.required_tools: # 如果有工具且没指定 json,则默认 auto request_body["tool_choice"] = "auto" request_body["tools"] = tools return request_body async def adapt_output(self, raw_response: Any) -> AIAgentOutput: """将 OpenAI 原始响应翻译为 AIAgentOutput""" from openai.types.chat import ChatCompletion completion = cast(ChatCompletion, raw_response) # 提取文本 text = "" if completion.choices and completion.choices[0].message.content: text = completion.choices[0].message.content # 提取结构化数据(JSON) structured_data = None if input_data.output_format == "json" and text: try: structured_data = json.loads(text) except json.JSONDecodeError: pass # 交给上层处理 # 提取 tool calls tool_calls = [] if completion.choices and completion.choices[0].message.tool_calls: for tc in completion.choices[0].message.tool_calls: tool_calls.append(ToolCall( name=tc.function.name, arguments=json.loads(tc.function.arguments), id=tc.id )) # 构建元数据 metadata = { "input_tokens": completion.usage.prompt_tokens if completion.usage else 0, "output_tokens": completion.usage.completion_tokens if completion.usage else 0, "total_tokens": completion.usage.total_tokens if completion.usage else 0, "model": completion.model, "finish_reason": completion.choices[0].finish_reason if completion.choices else None } return AIAgentOutput( text=text, structured_data=structured_data, tool_calls=tool_calls, metadata=metadata )现在看QwenAdapter,重点对比它如何处理 Qwen 特有的差异:
from dashscope import Generation # Qwen SDK from dashscope.common.error import RequestFailed class QwenAdapter: def __init__(self, api_key: str, model_name: str = "qwen2.5-7b-instruct"): self.api_key = api_key self.model_name = model_name def adapt_input(self, input_data: AIAgentInput) -> Dict[str, Any]: """Qwen 的 messages 格式是 [{"role": "user", "content": "..."}],但 tools 字段叫 functions""" messages = [{"role": "user", "content": input_data.user_query}] # Qwen 不支持 system role,所以把 system 提示和 context 合并到第一个 user message system_prompt = "You are a helpful AI assistant." if input_data.context: context_str = "\n".join([ f"{k}: {v}" for k, v in input_data.context.items() ]) system_prompt += f"\nContext:\n{context_str}" # Qwen 的第一消息必须是 user,所以把 system prompt 也塞进去 messages[0]["content"] = f"{system_prompt}\n\nUser query:\n{input_data.user_query}" # Qwen 的 tools 字段叫 functions,且格式是 [{"name": "...", "description": "...", "parameters": {...}}] functions = [] if input_data.required_tools: from my_tools_registry import TOOLS_REGISTRY for tool_name in input_data.required_tools: if tool_name in TOOLS_REGISTRY: # Qwen 的 function spec 需要转换 qwen_func = { "name": TOOLS_REGISTRY[tool_name]["name"], "description": TOOLS_REGISTRY[tool_name]["description"], "parameters": TOOLS_REGISTRY[tool_name]["parameters"] } functions.append(qwen_func) request_body = { "model": self.model_name, "input": {"messages": messages}, "parameters": { "max_tokens": input_data.max_tokens, "temperature": input_data.temperature, } } if functions: request_body["input"]["functions"] = functions # Qwen 的 function_call 是 "auto" 或 {"name": "xxx"} request_body["parameters"]["function_call"] = "auto" return request_body async def adapt_output(self, raw_response: Any) -> AIAgentOutput: """Qwen 响应格式:{"output": {"text": "...", "choices": [...]}}""" try: # DashScope SDK 的响应结构 text = raw_response.output.text if raw_response.output.text else "" structured_data = None if input_data.output_format == "json" and text: try: structured_data = json.loads(text) except json.JSONDecodeError: pass tool_calls = [] # Qwen 的 tool call 在 output.choices[0].message.function_call if (hasattr(raw_response, 'output') and hasattr(raw_response.output, 'choices') and raw_response.output.choices): choice = raw_response.output.choices[0] if (hasattr(choice.message, 'function_call') and choice.message.function_call): fc = choice.message.function_call tool_calls.append(ToolCall( name=fc.name, arguments=json.loads(fc.arguments) if fc.arguments else {}, id=None # Qwen 不返回 id )) metadata = { "input_tokens": raw_response.usage.input_tokens if hasattr(raw_response, 'usage') else 0, "output_tokens": raw_response.usage.output_tokens if hasattr(raw_response, 'usage') else 0, "model": self.model_name, "request_id": raw_response.request_id if hasattr(raw_response, 'request_id') else None } return AIAgentOutput( text=text, structured_data=structured_data, tool_calls=tool_calls, metadata=metadata ) except Exception as e: # Qwen 的异常类型是 RequestFailed,需统一包装 raise AIAgentError(f"QwenAdapter error: {str(e)}", adapter_name="qwen")提示:适配器的
adapt_input()方法里,我刻意把context合并到user_query,是因为 Qwen 的systemrole 支持不稳定,很多版本会忽略。这是实战中踩出来的坑——不要迷信文档,要以实测为准。我在测试 Qwen2.5-72B 时发现,加了systemrole 反而让模型更倾向于复述 system 提示,而不是回答问题。所以适配器不是机械翻译,而是“智能桥接”。
3.3 LLM API Gateway:基于策略的动态路由引擎
网关是整个模式的大脑。我们用一个简单的ModelGateway类实现,它支持三种路由策略:
from enum import Enum from typing import Dict, Any, Callable, Optional, Awaitable import asyncio import random class RoutingStrategy(str, Enum): STATIC = "static" # 固定模型 ROUND_ROBIN = "round_robin" # 轮询 WEIGHTED = "weighted" # 加权 CONTEXT_AWARE = "context_aware" # 上下文感知(需外部 LLM) class ModelGateway: def __init__(self): # 注册所有可用的 Adapter self.adapters: Dict[str, Callable[[], Any]] = {} # 路由规则:strategy -> config self.routing_rules: Dict[str, Dict[str, Any]] = {} # 轮询计数器 self.rr_counter = 0 def register_adapter(self, name: str, factory: Callable[[], Any]): """注册 Adapter 工厂函数,延迟实例化""" self.adapters[name] = factory def set_routing_rule(self, strategy: RoutingStrategy, config: Dict[str, Any]): """设置路由策略""" self.routing_rules[strategy.value] = config async def route_to_adapter(self, input_data: AIAgentInput) -> Any: """根据策略选择 Adapter 实例""" strategy = self.routing_rules.get("strategy", "static") if strategy == "static": model_name = self.routing_rules.get("static", "openai") if model_name not in self.adapters: raise ValueError(f"Adapter {model_name} not registered") return self.adapters[model_name]() elif strategy == "round_robin": available_models = list(self.adapters.keys()) if not available_models: raise ValueError("No adapters registered for round-robin") chosen_model = available_models[self.rr_counter % len(available_models)] self.rr_counter += 1 return self.adapters[chosen_model]() elif strategy == "weighted": weights = self.routing_rules.get("weights", {}) if not weights: raise ValueError("Weights config missing for weighted routing") models = list(weights.keys()) weights_list = [weights[m] for m in models] chosen_model = random.choices(models, weights=weights_list)[0] return self.adapters[chosen_model]() elif strategy == "context_aware": # 这里可以调用一个轻量级 LLM 来分析 input_data.context # 例如:用 7B 模型判断 query 是否涉及法律术语,决定是否升配 return await self._context_aware_route(input_data) else: raise ValueError(f"Unknown routing strategy: {strategy}") async def _context_aware_route(self, input_data: AIAgentInput) -> Any: """上下文感知路由:用小模型分析 query 复杂度""" # 简化版:关键词匹配 query_lower = input_data.user_query.lower() if any(word in query_lower for word in ["contract", "clause", "legal", "compliance"]): return self.adapters.get("deepseek", self.adapters["openai"])() elif any(word in query_lower for word in ["translate", "language", "中文"]): return self.adapters.get("qwen", self.adapters["openai"])() else: return self.adapters["openai"]() # 全局网关实例 GATEWAY = ModelGateway() # 注册适配器(实际项目中应从配置中心加载) GATEWAY.register_adapter("openai", lambda: OpenAIAdapter( api_key="sk-...", base_url="https://api.openai.com/v1" )) GATEWAY.register_adapter("qwen", lambda: QwenAdapter( api_key="YOUR_QWEN_KEY", model_name="qwen2.5-7b-instruct" )) GATEWAY.register_adapter("deepseek", lambda: DeepSeekAdapter( api_key="YOUR_DEEPSEEK_KEY" )) # 设置路由策略:默认静态,生产环境可切为 weighted GATEWAY.set_routing_rule(RoutingStrategy.STATIC, {"static": "openai"}) # GATEWAY.set_routing_rule(RoutingStrategy.WEIGHTED, {"weights": {"openai": 0.6, "qwen": 0.4}})注意:
register_adapter接收的是工厂函数lambda: Adapter(),而不是 Adapter 实例。这是因为 Adapter 可能包含网络连接、认证状态等有状态资源,必须按需创建,避免多线程/协程间状态污染。我在一个高并发客服系统中就吃过亏:把OpenAIAdapter实例存为全局变量,结果api_key被多个请求并发修改,导致一半请求 401。用工厂函数,每次route_to_adapter()都新建干净实例,彻底规避这个问题。
3.4 运行时执行器:统一异常处理与可观测性注入
最后是ModelExecutor,它把适配器、网关、重试、监控串起来:
import time import logging from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from typing import Type, Dict, Any class AIAgentError(Exception): """统一 Agent 错误基类""" def __init__(self, message: str, adapter_name: str = "", original_error: Exception = None): super().__init__(message) self.adapter_name = adapter_name self.original_error = original_error class ModelExecutor: def __init__(self, gateway: ModelGateway): self.gateway = gateway self.logger = logging.getLogger(__name__) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type((ConnectionError, TimeoutError, AIAgentError)) ) async def execute(self, input_data: AIAgentInput) -> AIAgentOutput: """执行模型调用,含重试、监控、错误统一""" start_time = time.time() adapter = None try: # 1. 路由获取 Adapter adapter = await self.gateway.route_to_adapter(input_data) # 2. 适配输入 request_body = adapter.adapt_input(input_data) # 3. 调用模型(Adapter 内部实现) raw_response = await adapter.invoke(request_body) # Adapter 需实现 invoke 方法 # 4. 适配输出 output = await adapter.adapt_output(raw_response) # 5. 注入耗时元数据 latency_ms = int((time.time() - start_time) * 1000) output.metadata["latency_ms"] = latency_ms output.metadata["adapter_name"] = adapter.__class__.__name__.replace("Adapter", "") self.logger.info( f"Model execution success: adapter={output.metadata['adapter_name']} " f"latency={latency_ms}ms tokens={output.metadata.get('total_tokens', 0)}" ) return output except Exception as e: latency_ms = int((time.time() - start_time) * 1000) error_msg = f"Model execution failed after {latency_ms}ms: {str(e)}" # 包装为统一错误 if not isinstance(e, AIAgentError): wrapped_error = AIAgentError( error_msg, adapter_name=adapter.__class__.__name__.replace("Adapter", "") if adapter else "unknown", original_error=e ) raise wrapped_error else: raise e # 已是统一错误,直接抛出 # 初始化执行器 EXECUTOR = ModelExecutor(GATEWAY)这个ModelExecutor的价值在于:它把所有“脏活累活”都包圆了。你作为业务开发者,只需要调EXECUTOR.execute(input_data),就能得到一个干净的AIAgentOutput,里面包含了文本、结构化数据、工具调用、token 用量、延迟、模型名——所有你需要的信息,而且格式统一。再也不用在每个 Chain 里写try...except,不用手动算time.time(),不用纠结openai.RateLimitError和dashscope.RequestFailed怎么 catch。这就是抽象的价值。
4. LangGraph 集成:将 Model Agnostic Pattern 编排进有状态的 Agent 图
4.1 定义可复用的 State 和 Node
LangGraph 的威力,在于把ModelExecutor封装成一个可复用的 Node。我们定义一个通用的llm_node:
from langgraph.graph import StateGraph, END from typing import TypedDict, Annotated, Sequence import operator class GraphState(TypedDict): """LangGraph 共享状态""" input_data: AIAgentInput llm_output: Optional[AIAgentOutput] = None error: Optional[str] = None retry_count: int = 0 # 可以添加更多字段,如 retrieved_docs, router_decision 等 # 定义一个通用的 LLM 调用节点 async def llm_node(state: GraphState) -> GraphState: """调用 LLM 的通用节点""" try: # 从 state 中取出 input_data input_data = state["input_data"] # 执行模型调用 output = await EXECUTOR.execute(input_data) return { "llm_output": output, "error": None, "retry_count": state.get("retry_count", 0) } except AIAgentError as e: # 记录错误,但不中断流程,留给 fallback 节点处理 error_msg = f"LLM call failed: {e.adapter_name} - {str(e)}" return { "error": error_msg, "llm_output": None, "retry_count": state.get("retry_count", 0) + 1 } # 定义 fallback 节点:当主 LLM 失败时,降级到备用模型 async def fallback_node(state: GraphState) -> GraphState: """降级节点:切换到备用 Adapter""" if not state.get("error"): return state # 简单降级:从 openai 切到 qwen current_adapter = "openai" fallback_adapter = "qwen" # 修改 input_data 的上下文,标记为降级调用 new_input = state["input_data"].model_copy(update={ "context": { **state["input_data"].context, "fallback_triggered": True, "original_error": state["error"] } }) # 临时覆盖网关策略 original_strategy = GATEWAY.routing_rules.get("strategy") GATEWAY.set_routing_rule(RoutingStrategy.STATIC, {"static": fallback_adapter}) try: output = await EXECUTOR.execute(new_input) return { "llm_output": output, "error": None, "retry_count": state.get("retry_count", 0) } except Exception as e: # 降级也失败,返回原始错误 return { "error": f"Fallback to {fallback_adapter} also failed: {str(e)}", "llm_output": None, "retry_count": state.get("retry_count", 0) + 1 } finally: # 恢复原始策略 if original_strategy: GATEWAY.set_routing_rule(RoutingStrategy.STATIC, {"static": current_adapter}) # 定义 router 节点:根据 query 复杂度决定模型 async def router_node(state: GraphState) -> str: """路由节点:返回下一个节点名""" input_data = state["input_data"] # 简单规则:字符数 > 5000 且含法律术语,走 high_precision if (len(input_data.user_query) > 5000 and any(term in input_data.user_query.lower() for term in ["contract", "clause", "section"])): return "high_precision_llm" else: return "default_llm"4.2 构建完整的 Agent 图:支持动态路由、降级、重试
现在,我们把所有节点编排成一个有状态的图:
# 创建图 workflow = StateGraph(GraphState) # 添加节点 workflow.add_node("router", router_node) workflow.add_node("default_llm", llm_node) workflow.add_node("high_precision_llm", llm_node) workflow.add_node("fallback", fallback_node) workflow.add_node("end", lambda state: state) # 终止节点 # 设置边 workflow.set_conditional_entry_point( router_node, { "default_llm": "default_llm", "high_precision_llm": "high_precision_llm" } ) # default_llm 成功则到 end,失败