# Code
import json
import re

from langchain_core.tools import StructuredTool, Tool
from langchain_openai import ChatOpenAI
from pydantic import ValidationError

from lfx.base.agents.agent import LCToolsAgentComponent
from lfx.base.agents.events import ExceptionWithMessageError
from lfx.base.models.model_input_constants import (
    ALL_PROVIDER_FIELDS,
    MODEL_DYNAMIC_UPDATE_FIELDS,
    MODEL_PROVIDERS_DICT,
    MODEL_PROVIDERS_LIST,
    MODELS_METADATA,
)
from lfx.base.models.model_utils import get_model_name
from lfx.components.helpers import CurrentDateComponent
from lfx.components.langchain_utilities.tool_calling import ToolCallingAgentComponent
from lfx.components.models_and_agents.memory import MemoryComponent
from lfx.custom.custom_component.component import get_component_toolkit
from lfx.custom.utils import update_component_build_config
from lfx.helpers.base_model import build_model_from_schema
from lfx.inputs.inputs import BoolInput, SecretStrInput, StrInput
from lfx.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output, TableInput
from lfx.log.logger import logger
from lfx.schema.data import Data
from lfx.schema.dotdict import dotdict
from lfx.schema.message import Message
from lfx.schema.table import EditMode


def set_advanced_true(component_input):
    """Mark *component_input* as advanced (mutates it in place) and return it."""
    component_input.advanced = True
    return component_input


class AgentComponent(ToolCallingAgentComponent):
    """Tool-calling agent component wired to any OpenAI-compatible endpoint.

    This variant replaces the multi-provider model selection with a single
    "Custom OpenAI Compatible" provider: the user supplies an API key, a
    base URL, and a model name, and the agent talks to that endpoint via
    ``ChatOpenAI``.
    """

    display_name: str = "Agent"
    description: str = "Define the agent's instructions, then enter a task to complete using tools."
    documentation: str = "https://docs.langflow.org/agents"
    icon = "bot"
    beta = False
    name = "Agent"

    # Memory-related inputs are exposed but forced behind the "advanced" toggle.
    memory_inputs = [set_advanced_true(component_input) for component_input in MemoryComponent().inputs]

    # Reuse the OpenAI provider inputs, minus "json_mode" (structured output is
    # handled by this component's own output-schema machinery instead).
    if "OpenAI" in MODEL_PROVIDERS_DICT:
        openai_inputs_filtered = [
            input_field
            for input_field in MODEL_PROVIDERS_DICT["OpenAI"]["inputs"]
            if not (hasattr(input_field, "name") and input_field.name == "json_mode")
        ]
    else:
        openai_inputs_filtered = []

    inputs = [
        DropdownInput(
            name="agent_llm",
            display_name="Model Provider",
            info="选择模型服务",
            options=["Custom OpenAI Compatible"],
            value="Custom OpenAI Compatible",
            real_time_refresh=True,
            refresh_button=False,
            input_types=[],
        ),
        SecretStrInput(
            name="api_key",
            display_name="API Key",
            info="第三方模型 API Key",
            required=True,
        ),
        StrInput(
            name="base_url",
            display_name="Base URL",
            info="第三方模型接口地址 https://xxx/v1",
            required=True,
            show=True,
        ),
        StrInput(
            name="model_name",
            display_name="Model Name",
            info="模型名称 如 qwen-turbo / deepseek-chat",
            required=True,
            show=True,
        ),
        IntInput(
            name="max_output_tokens",
            display_name="Max Output Tokens",
            info="最大生成token",
            show=True,
        ),
        MultilineInput(
            name="system_prompt",
            display_name="Agent Instructions",
            value="You are a helpful assistant that can use tools to answer questions and perform tasks.",
            advanced=False,
        ),
        MessageTextInput(
            name="context_id",
            display_name="Context ID",
            value="",
            advanced=True,
        ),
        IntInput(
            name="n_messages",
            display_name="Number of Chat History Messages",
            value=100,
            advanced=True,
            show=True,
        ),
        MultilineInput(
            name="format_instructions",
            display_name="Output Format Instructions",
            value=(
                "You are an AI that extracts structured JSON objects from unstructured text. "
                "Use a predefined schema with expected types (str, int, float, bool, dict). "
                "Extract ALL relevant instances that match the schema - if multiple patterns exist, capture them all. "
                "Fill missing or ambiguous values with defaults: null for missing values. "
                "Remove exact duplicates but keep variations that have different field values. "
                "Always return valid JSON in the expected format, never throw errors. "
                "If multiple objects can be extracted, return them all in the structured format."
            ),
            advanced=True,
        ),
        TableInput(
            name="output_schema",
            display_name="Output Schema",
            advanced=True,
            required=False,
            value=[],
            table_schema=[
                {
                    "name": "name",
                    "display_name": "Name",
                    "type": "str",
                    "default": "field",
                    "edit_mode": EditMode.INLINE,
                },
                {
                    "name": "description",
                    "display_name": "Description",
                    "type": "str",
                    "default": "description",
                    "edit_mode": EditMode.POPOVER,
                },
                {
                    "name": "type",
                    "display_name": "Type",
                    "type": "str",
                    "options": ["str", "int", "float", "bool", "dict"],
                    "default": "str",
                },
                {
                    "name": "multiple",
                    "display_name": "As List",
                    "type": "boolean",
                    "default": "False",
                },
            ],
        ),
        *LCToolsAgentComponent.get_base_inputs(),
        BoolInput(
            name="add_current_date_tool",
            display_name="Current Date",
            advanced=True,
            value=True,
        ),
    ]

    outputs = [
        Output(name="response", display_name="Response", method="message_response"),
    ]

    async def get_agent_requirements(self):
        """Resolve the LLM, chat history, and tool list needed to run the agent.

        Returns:
            Tuple of ``(llm_model, chat_history, tools)``.

        Raises:
            ValueError: If no language model could be constructed.
        """
        llm_model, display_name = await self.get_llm()
        if llm_model is None:
            raise ValueError("No language model selected.")
        self.model_name = get_model_name(llm_model, display_name=display_name)

        self.chat_history = await self.get_memory_data()
        await logger.adebug(f"Retrieved {len(self.chat_history)} chat history messages")
        # Normalize a single Message into a one-element list.
        if isinstance(self.chat_history, Message):
            self.chat_history = [self.chat_history]

        if self.add_current_date_tool:
            if not isinstance(self.tools, list):
                self.tools = []
            # CurrentDateComponent's toolkit yields exactly one tool.
            current_date_tool = (await CurrentDateComponent(**self.get_base_args()).to_toolkit()).pop(0)
            self.tools.append(current_date_tool)

        self.set_tools_callbacks(self.tools, self._get_shared_callbacks())
        return llm_model, self.chat_history, self.tools

    async def message_response(self) -> Message:
        """Run the agent on ``input_value`` and return its response message."""
        try:
            llm_model, self.chat_history, self.tools = await self.get_agent_requirements()
            self.set(
                llm=llm_model,
                tools=self.tools or [],
                chat_history=self.chat_history,
                input_value=self.input_value,
                system_prompt=self.system_prompt,
            )
            agent = self.create_agent_runnable()
            result = await self.run_agent(agent)
            self._agent_result = result
            return result
        except Exception as e:
            await logger.aerror(f"Error: {e!s}")
            raise

    def _preprocess_schema(self, schema):
        """Coerce a raw table-input schema into the shape build_model_from_schema expects.

        Stringifies name/type/description and normalizes the ``multiple`` flag
        (which arrives as a string from the table UI) into a real bool.
        """
        processed_schema = []
        for field in schema:
            processed_field = {
                "name": str(field.get("name", "field")),
                "type": str(field.get("type", "str")),
                "description": str(field.get("description", "")),
                "multiple": field.get("multiple", False),
            }
            if isinstance(processed_field["multiple"], str):
                processed_field["multiple"] = processed_field["multiple"].lower() in ["true", "1", "t", "y", "yes"]
            processed_schema.append(processed_field)
        return processed_schema

    async def build_structured_output_base(self, content: str):
        """Extract JSON from *content* and validate it against the output schema.

        Tries to parse the whole string as JSON first, then falls back to the
        first ``{...}`` span. When an output schema is configured, each object
        is validated with a pydantic model built from it; validation failures
        are reported inline rather than raised.
        """
        json_pattern = r"\{.*\}"
        schema_error_msg = "Try setting an output schema"
        json_data = None
        try:
            json_data = json.loads(content)
        except json.JSONDecodeError:
            json_match = re.search(json_pattern, content, re.DOTALL)
            if json_match:
                try:
                    json_data = json.loads(json_match.group())
                except json.JSONDecodeError:
                    return {"content": content, "error": schema_error_msg}
            else:
                return {"content": content, "error": schema_error_msg}

        # No schema configured: return the parsed JSON untouched.
        if not self.output_schema:
            return json_data

        try:
            processed_schema = self._preprocess_schema(self.output_schema)
            output_model = build_model_from_schema(processed_schema)
            if isinstance(json_data, list):
                validated = []
                for item in json_data:
                    try:
                        v = output_model.model_validate(item)
                        validated.append(v.model_dump())
                    except ValidationError as e:
                        # Keep the offending item alongside its error instead of failing.
                        validated.append({"data": item, "error": str(e)})
                return validated
            v = output_model.model_validate(json_data)
            return [v.model_dump()]
        except Exception as e:  # noqa: BLE001 - best-effort: fall back to raw JSON on any schema failure
            await logger.aerror(f"Schema error: {e}")
            return json_data

    async def json_response(self) -> Data:
        """Run the agent and return its output as structured JSON ``Data``.

        Appends the format instructions to the system prompt, runs the agent,
        and parses/validates the result. Errors are returned as
        ``Data(data={"error": ...})`` rather than raised.
        """
        try:
            system_components = []
            if self.system_prompt:
                system_components.append(self.system_prompt)
            if self.format_instructions:
                system_components.append(f"Format: {self.format_instructions}")
            combined = "\n\n".join(system_components)

            llm_model, self.chat_history, self.tools = await self.get_agent_requirements()
            self.set(
                llm=llm_model,
                tools=self.tools,
                chat_history=self.chat_history,
                input_value=self.input_value,
                system_prompt=combined,
            )
            agent = self.create_agent_runnable()
            result = await self.run_agent(agent)
            content = result.content if hasattr(result, "content") else str(result)

            output = await self.build_structured_output_base(content)
            # Unwrap a single-element list for a cleaner payload.
            if isinstance(output, list) and len(output) == 1:
                return Data(data=output[0])
            return Data(data=output if output else {"content": content})
        except Exception as e:  # noqa: BLE001 - surface the error as data, per output contract
            return Data(data={"error": str(e)})

    async def get_memory_data(self):
        """Fetch chat history for this session, excluding the current input message."""
        messages = (
            await MemoryComponent(**self.get_base_args())
            .set(
                session_id=self.graph.session_id,
                context_id=self.context_id,
                order="Ascending",
                n_messages=self.n_messages,
            )
            .retrieve_messages()
        )
        return [m for m in messages if getattr(m, "id", None) != getattr(self.input_value, "id", None)]

    # ====== Core change: support arbitrary third-party OpenAI-compatible models ======
    async def get_llm(self):
        """Build a ``ChatOpenAI`` client against the user-supplied endpoint.

        Returns:
            Tuple of ``(llm, display_name)``.

        Raises:
            ValueError: If the client cannot be constructed.
        """
        try:
            # Always speak the OpenAI-compatible protocol, regardless of vendor.
            llm = ChatOpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
                model=self.model_name,
                max_tokens=self.max_output_tokens or 1024,
            )
            return llm, "Custom OpenAI Compatible"
        except Exception as e:
            await logger.aerror(f"LLM init failed: {e!s}")
            raise ValueError(f"无法初始化模型: {e}") from e

    def _build_llm_model(self, component, inputs, prefix=""):
        # No-op: kept for compatibility with the base framework's interface.
        return None

    # ====== Remaining framework hooks: trimmed-down no-op versions ======
    def set_component_params(self, component):
        """Return *component* unchanged (single-provider setup needs no tweaks)."""
        return component

    def delete_fields(self, build_config, fields):
        """No-op: dynamic provider fields are not used in this variant."""

    def update_input_types(self, build_config):
        """Return *build_config* unchanged."""
        return build_config

    async def update_build_config(self, build_config, field_value, field_name=None):
        """Return *build_config* unchanged (no dynamic refresh needed)."""
        return build_config

    async def _get_tools(self) -> list[Tool]:
        """Return no extra component-level tools."""
        return []