function initial_guess = enhancedInitialGuess(x, y, model_type, param_names)
% Heuristic starting values for the nonlinear fit.
%
% Inputs:
%   x, y        - data vectors
%   model_type  - model label; the quadratic and exponential labels get
%                 model-specific estimates, any other label a generic guess
%   param_names - cell array of parameter names; sizes the generic guess and
%                 is used to print the estimates
% Output:
%   initial_guess - row vector of starting values
fprintf('增强型智能初始值估计:\n');
switch model_type
    case '二次函数(y=a*x^2+b*x+c)'
        try
            % robustfit (Statistics Toolbox) returns [intercept; coef(x^2); coef(x)]
            p = robustfit([x(:).^2, x(:)], y(:));
            initial_guess = [p(2), p(3), p(1)]; % [a, b, c]
        catch
            % robustfit unavailable or failed: ordinary least squares
            p = polyfit(x, y, 2);
            initial_guess = [p(1), p(2), p(3)];
        end
    case '指数函数(y=a*exp(b*x))'
        % Log-linearise: log(s*y) = log(|a|) + b*x, where s is the sign of a.
        if all(y < 0)
            % Entirely negative data is fitted with a < 0 via log(-y)
            sign_a = -1;
            y_adj = -y;
            fallback = [mean(y), -0.01];
        elseif any(y > 0)
            sign_a = 1;
            y_min = min(y);
            if y_min <= 0
                % Shift so every value is strictly positive before taking logs
                y_adj = y - y_min + 0.01 * (max(y) - min(y));
            else
                y_adj = y;
            end
            fallback = [max(y)/2, -0.01];
        else
            % No positive values but not all negative (zeros present):
            % a log-linear fit is impossible
            sign_a = 0;
            fallback = [mean(y), -0.01];
        end
        initial_guess = fallback;
        if sign_a ~= 0
            try
                p = polyfit(x, log(y_adj), 1);
                candidate = [sign_a * exp(p(2)), p(1)];
                % polyfit only warns on degenerate input (e.g. constant x), so
                % non-finite output has to be caught explicitly
                if all(isfinite(candidate))
                    initial_guess = candidate;
                end
            catch
                % keep the fallback guess
            end
        end
    otherwise
        % Generic strategy: level = mean(y), slope ~ 10% of the data trend,
        % third parameter = min(y), any further ones = 1
        n_params = length(param_names);
        initial_guess = ones(1, n_params);
        initial_guess(1) = mean(y);
        if n_params > 1
            x_span = max(x) - min(x);
            if x_span == 0, x_span = 1; end % avoid Inf/NaN when all x are equal
            initial_guess(2) = 0.1 * (max(y) - min(y)) / x_span;
        end
        if n_params > 2
            initial_guess(3) = min(y);
        end
end
% Print the estimates (guard against a name list longer than the guess)
for i = 1:min(length(param_names), length(initial_guess))
    fprintf(' %s: %.6g\n', param_names{i}, initial_guess(i));
end
end
function [x_processed, y_processed, scaling_factors] = preprocessData(x, y)
% Clean and standardise the data before fitting.
%   1) Outliers: drop points whose y is 3 or more sample standard deviations
%      from mean(y). This step is skipped when std(y) is not positive, or when
%      it would leave fewer than 3 points (a warning is printed in that case).
%   2) Z-score x and y; a zero standard deviation is replaced by 1.
% Returns the standardised vectors and a struct (x_mean, x_std, y_mean, y_std)
% holding the statistics that were used.
x_clean = x;
y_clean = y;
sigma_y = std(y);
if sigma_y > 0
    keep = abs(y - mean(y)) < 3 * sigma_y;
    if length(x(keep)) < 3
        % Too few points would survive: keep the original data
        fprintf('警告: 数据点过少,跳过异常值去除\n');
    else
        x_clean = x(keep);
        y_clean = y(keep);
        fprintf('数据预处理: 去除 %d 个异常值\n', length(x) - length(x_clean));
    end
end
stats = struct('x_mean', mean(x_clean), 'x_std', std(x_clean), ...
    'y_mean', mean(y_clean), 'y_std', std(y_clean));
% A zero spread would make the division below blow up
if stats.x_std == 0, stats.x_std = 1; end
if stats.y_std == 0, stats.y_std = 1; end
x_processed = (x_clean - stats.x_mean) / stats.x_std;
y_processed = (y_clean - stats.y_mean) / stats.y_std;
scaling_factors = stats;
end
function [best_params, best_gof, best_func] = multiAlgorithmFit(model_func, x, y, init_guess, param_names)
% Fit model_func(p, x) to y with several optimisers and keep the best by R².
%
% Tries lsqcurvefit (Optimization Toolbox), fminsearch and patternsearch
% (Global Optimization Toolbox), in that order. An optimiser that errors
% (including when its toolbox is missing) or returns exitflag <= 0 is skipped,
% and the reason is printed. Ties in R² go to the earlier algorithm.
%
% Inputs:
%   model_func  - handle @(p, x) returning predictions the same size as y
%   x, y        - data vectors
%   init_guess  - starting parameter vector
%   param_names - parameter names (unused here; kept for interface compatibility)
% Outputs:
%   best_params - fitted parameter vector ([] if every optimiser failed)
%   best_gof    - struct with rsquare, adjrsquare, rmse, mae, ss_res and
%                 algorithm ([] on failure). adjrsquare is NaN when there are
%                 too few points for it to be defined (n <= p + 1).
%   best_func   - handle @(x) model_func(best_params, x) ([] on failure)
algorithms = {'lsqcurvefit', 'fminsearch', 'patternsearch'};
% results(i, :) = {params, R², sum of squared residuals}
results = cell(length(algorithms), 3);
obj_func = @(p) sum((model_func(p, x) - y).^2);
for i = 1:length(algorithms)
    try
        switch algorithms{i}
            case 'lsqcurvefit'
                options = optimoptions('lsqcurvefit', 'Display', 'off', ...
                    'MaxFunctionEvaluations', 5000, 'FunctionTolerance', 1e-10, ...
                    'OptimalityTolerance', 1e-10, 'StepTolerance', 1e-10);
                % The 2nd output (resnorm) is the scalar sum of squared residuals,
                % matching the fval of the other two solvers.
                [params, ssr, ~, exitflag] = lsqcurvefit(model_func, init_guess, x, y, [], [], options);
            case 'fminsearch'
                options = optimset('Display', 'off', 'MaxFunEvals', 5000, 'TolFun', 1e-10, 'TolX', 1e-10);
                [params, ssr, exitflag] = fminsearch(obj_func, init_guess, options);
            case 'patternsearch'
                options = optimoptions('patternsearch', 'Display', 'off', ...
                    'FunctionTolerance', 1e-10, 'StepTolerance', 1e-10);
                [params, ssr, exitflag] = patternsearch(obj_func, init_guess, [], [], [], [], [], [], options);
        end
        if exitflag <= 0
            error('优化失败');
        end
        y_fit = model_func(params, x);
        SS_tot = sum((y - mean(y)).^2);
        SS_res = sum((y - y_fit).^2);
        rsquare = max(0, 1 - SS_res/SS_tot); % clamp to non-negative
        results{i, 1} = params;
        results{i, 2} = rsquare;
        results{i, 3} = ssr;
    catch err
        % Best effort: report the failure and try the next algorithm
        fprintf('算法 %s 失败: %s\n', algorithms{i}, err.message);
        results{i, 1} = [];
        results{i, 2} = -inf;
        results{i, 3} = inf;
    end
end
% Pick the successful algorithm with the highest R²
best_rsquare = -inf;
best_idx = 1;
for i = 1:length(algorithms)
    if results{i, 2} > best_rsquare && ~isempty(results{i, 1})
        best_rsquare = results{i, 2};
        best_idx = i;
    end
end
if best_rsquare == -inf
    best_params = [];
    best_gof = [];
    best_func = [];
    return;
end
best_params = results{best_idx, 1};
best_func = @(x) model_func(best_params, x);
% Full goodness-of-fit summary for the chosen parameters
y_fit = best_func(x);
SS_tot = sum((y - mean(y)).^2);
SS_res = sum((y - y_fit).^2);
rsquare = max(0, 1 - SS_res/SS_tot);
n = length(y);
p = length(init_guess);
dof = n - p - 1;
if dof > 0
    adjrsquare = 1 - (1-rsquare)*(n-1)/dof;
else
    % Undefined with this few points (the formula would divide by <= 0)
    adjrsquare = NaN;
end
rmse = sqrt(mean((y - y_fit).^2));
mae = mean(abs(y - y_fit));
best_gof = struct('rsquare', rsquare, 'adjrsquare', adjrsquare, ...
    'rmse', rmse, 'mae', mae, 'ss_res', SS_res, ...
    'algorithm', algorithms{best_idx});
fprintf('最佳算法: %s, R² = %.6f\n', algorithms{best_idx}, rsquare);
end
function params_original = transformParamsToOriginalScale(params, scaling_factors, model_type, param_names)
% Map parameters fitted on standardised data back to the original units.
%
% The fit used x' = (x - mx)/sx and y' = (y - my)/sy, where mx, sx, my, sy
% are the x_mean, x_std, y_mean and y_std fields of scaling_factors.
%
% Quadratic (exact). Substituting into y' = a'x'^2 + b'x' + c' gives
%   a = a' sy/sx^2
%   b = b' sy/sx - 2 a' mx sy/sx^2
%   c = c' sy + my - b' mx sy/sx + a' mx^2 sy/sx^2
% Exponential (approximate). y' = a' exp(b' x') becomes
%   y = [a' sy exp(-b' mx/sx)] * exp((b'/sx) x) + my.
%   The additive offset my cannot be expressed by y = a*exp(b*x). It is folded
%   into a so that the curve is reproduced exactly at x = 0:
%   a = a' sy exp(-b' mx/sx) + my,  b = b'/sx.
%   A warning is issued when my ~= 0.
% Other models: parameters are returned unchanged, with a warning.
% param_names is unused (kept for interface compatibility).
params_original = params;
sx = scaling_factors.x_std;
mx = scaling_factors.x_mean;
sy = scaling_factors.y_std;
my = scaling_factors.y_mean;
switch model_type
    case '二次函数(y=a*x^2+b*x+c)'
        params_original(1) = params(1) * sy / (sx^2);
        params_original(2) = params(2) * sy / sx - ...
            2 * params(1) * mx * sy / (sx^2);
        params_original(3) = params(3) * sy + my - ...
            params(2) * mx * sy / sx + ...
            params(1) * mx^2 * sy / (sx^2);
    case '指数函数(y=a*exp(b*x))'
        % Note the minus sign: exp(b' x') = exp(b' x/sx) * exp(-b' mx/sx)
        params_original(1) = params(1) * sy * exp(-params(2) * mx / sx) + my;
        params_original(2) = params(2) / sx;
        if my ~= 0
            warning('指数模型: 标准化后的y偏移量无法用a*exp(b*x)精确表示, 参数a为近似值');
        end
    otherwise
        % Custom models: no general back-transform is available
        warning('自定义模型的参数尺度转换可能不精确');
end
end
function visualizeEnhancedResults(fit_result, params_opt, data, gof, optimal_point, model_formula, model_type)
% Draw a 2x2 summary figure for a completed fit:
%   (1) data points with the fitted curve over the data range
%   (2) residuals against predicted values, with a zero reference line
%   (3) the fitted curve over the data range widened by 20% on each side,
%       with optimal_point marked
%   (4) a text panel listing parameters, fit statistics and the optimum
% data must provide x, y, x_dim and y_dim. gof must provide rsquare,
% adjrsquare, rmse, mae and algorithm.
% model_formula and model_type are accepted but not used here.
% NOTE(review): fit_result is evaluated on vectors, so it must be vectorised.
figure('Position', [100, 100, 1200, 800], 'Name', '高精度分析结果', 'NumberTitle', 'off');

% --- Panel 1: measured data and fitted curve
subplot(2,2,1);
scatter(data.x, data.y, 50, 'bo', 'filled', 'DisplayName', '原始数据');
hold on;
x_dense = linspace(min(data.x), max(data.x), 500);
plot(x_dense, fit_result(x_dense), 'r-', 'LineWidth', 2, 'DisplayName', '拟合曲线');
title(sprintf('模型拟合 (R² = %.6f)', gof.rsquare));
xlabel(sprintf('x [%s]', data.x_dim));
ylabel(sprintf('y [%s]', data.y_dim));
legend('Location', 'best');
grid on;
box on;

% --- Panel 2: residuals vs. predictions
subplot(2,2,2);
predicted = fit_result(data.x);
resid = data.y - predicted;
scatter(predicted, resid, 50, 'filled', 'DisplayName', '残差');
hold on;
plot([min(predicted), max(predicted)], [0, 0], 'r--', 'LineWidth', 2, 'DisplayName', '零线');
title('残差分析');
xlabel('预测值');
ylabel('残差');
legend('Location', 'best');
grid on;

% --- Panel 3: extended curve with the optimum marked
subplot(2,2,3);
lo = min(data.x);
hi = max(data.x);
pad = (hi - lo) * 0.2;
x_wide = linspace(lo - pad, hi + pad, 500);
plot(x_wide, fit_result(x_wide), 'b-', 'LineWidth', 2, 'DisplayName', '拟合曲线');
hold on;
f_opt = fit_result(optimal_point);
plot(optimal_point, f_opt, 'ro', 'MarkerSize', 10, ...
    'MarkerFaceColor', 'r', 'DisplayName', sprintf('最优点 (%.6g, %.6g)', optimal_point, f_opt));
title('最优解定位');
xlabel(sprintf('x [%s]', data.x_dim));
ylabel(sprintf('f(x) [%s]', data.y_dim));
legend('Location', 'best');
grid on;

% --- Panel 4: textual summary
subplot(2,2,4);
axis off;
param_lines = cell(1, length(params_opt));
for k = 1:length(params_opt)
    param_lines{k} = sprintf('%s = %.6g\n', getParamName(k), params_opt(k));
end
param_text = [sprintf('拟合参数:\n'), param_lines{:}];
stats_text = [sprintf('\n拟合统计:\n'), ...
    sprintf('R² = %.6f\n', gof.rsquare), ...
    sprintf('调整R² = %.6f\n', gof.adjrsquare), ...
    sprintf('RMSE = %.6f\n', gof.rmse), ...
    sprintf('MAE = %.6f\n', gof.mae), ...
    sprintf('算法: %s\n', gof.algorithm)];
opt_text = sprintf('\n最优解:\nx* = %.6g\nf(x*) = %.6g', optimal_point, f_opt);
text(0.1, 0.8, param_text, 'VerticalAlignment', 'top', 'FontName', 'Courier');
text(0.1, 0.5, stats_text, 'VerticalAlignment', 'top', 'FontName', 'Courier');
text(0.1, 0.2, opt_text, 'VerticalAlignment', 'top', 'FontName', 'Courier');
title('详细结果');
end
function name = getParamName(idx)
% Display name for the idx-th parameter (1-based): 'a'..'j' for the first
% ten, 'p<idx>' beyond that.
letters = 'abcdefghij';
if idx > numel(letters)
    name = sprintf('p%d', idx);
else
    name = letters(idx);
end
end
%% ==================== 保留原有辅助函数(保持不变) ====================
function [data, should_return] = getInitialData()
% Ask the user for the initial x/y data and their units via a dialog.
%
% The dialog is shown again until the input is valid:
%   - numeric, real and finite values
%   - at least 3 points
%   - x and y of equal length
% Empty unit fields default to '1'.
% Outputs:
%   data          - struct with row vectors x, y and unit strings x_dim, y_dim
%                   ([] if the user cancelled)
%   should_return - true if the user cancelled the dialog
should_return = false;
while true
    input_data = inputdlg({'自变量x (空格分隔):', ...
        '因变量y (空格分隔):', ...
        'x单位 (如m,kg,s):', ...
        'y单位 (如m,kg,s):'}, ...
        '数据输入', [1 50; 1 50; 1 10; 1 10]);
    if isempty(input_data)
        should_return = true;
        data = [];
        return;
    end
    try
        % NOTE(review): str2num evaluates its argument with eval, so any MATLAB
        % code typed into the dialog is executed. Acceptable for a local
        % interactive tool; switch to str2double on split tokens if the input
        % can ever come from an untrusted source.
        x_vals = str2num(input_data{1}); %#ok<ST2NM>
        y_vals = str2num(input_data{2}); %#ok<ST2NM>
        assert(isnumeric(x_vals) && isreal(x_vals) && all(isfinite(x_vals(:))));
        assert(isnumeric(y_vals) && isreal(y_vals) && all(isfinite(y_vals(:))));
        % Force row vectors so [data.x; data.y] below is always a 2-row table,
        % even for semicolon-separated (column) input
        x_vals = reshape(x_vals, 1, []);
        y_vals = reshape(y_vals, 1, []);
        assert(numel(x_vals) >= 3 && numel(x_vals) == numel(y_vals));
        data.x = x_vals;
        data.y = y_vals;
        % Unit strings; '1' means dimensionless
        data.x_dim = strtrim(input_data{3});
        data.y_dim = strtrim(input_data{4});
        if isempty(data.x_dim), data.x_dim = '1'; end
        if isempty(data.y_dim), data.y_dim = '1'; end
        break;
    catch
        errordlg('输入无效! 需满足: 1) 数值格式 2) 数据点≥3 3) x/y长度一致');
    end
end
fprintf('\n=== 输入数据 ===\n');
disp(array2table([data.x; data.y], 'RowNames', {'x', 'y'}));
fprintf('单位信息: x [%s], y [%s]\n', data.x_dim, data.y_dim);
end
function [new_data, should_return] = getAdditionalData(old_data)
% Ask the user for extra x/y points to append, reusing old_data's units.
%
% The dialog is shown again until the input is valid:
%   - numeric, real and finite values
%   - at least 1 point
%   - x and y of equal length
% Outputs:
%   new_data      - struct with row vectors x, y and unit strings x_dim, y_dim
%                   copied from old_data; an empty struct() if the user cancelled
%   should_return - true if the user cancelled the dialog
should_return = false;
new_data = struct();
prompt = {sprintf('追加的x值 (空格分隔,当前单位[%s]):', old_data.x_dim), ...
    sprintf('追加的y值 (空格分隔,当前单位[%s]):', old_data.y_dim)};
while true
    input_data = inputdlg(prompt, '追加数据输入', [1 50; 1 50]);
    if isempty(input_data)
        should_return = true;
        return;
    end
    try
        % NOTE(review): str2num evaluates its argument with eval; see the same
        % caveat for the initial data dialog.
        x_vals = str2num(input_data{1}); %#ok<ST2NM>
        y_vals = str2num(input_data{2}); %#ok<ST2NM>
        assert(isnumeric(x_vals) && isreal(x_vals) && all(isfinite(x_vals(:))));
        assert(isnumeric(y_vals) && isreal(y_vals) && all(isfinite(y_vals(:))));
        % Row vectors keep [x; y] a 2-row table for semicolon-separated input too
        x_vals = reshape(x_vals, 1, []);
        y_vals = reshape(y_vals, 1, []);
        assert(numel(x_vals) >= 1 && numel(x_vals) == numel(y_vals));
        new_data.x = x_vals;
        new_data.y = y_vals;
        new_data.x_dim = old_data.x_dim;
        new_data.y_dim = old_data.y_dim;
        fprintf('\n=== 追加的数据点 ===\n');
        disp(array2table([new_data.x; new_data.y], 'RowNames', {'x_add', 'y_add'}));
        break;
    catch
        % Discard any fields set by the failed attempt so a later cancel
        % returns a clean empty struct
        new_data = struct();
        errordlg('输入无效! 需满足: 1) 数值格式 2) 数据点≥1 3) x/y长度一致');
    end
end
end
function formula = vectorizeMathExpr(formula)
% Convert a scalar math expression into element-wise MATLAB syntax.
%   - '*', '/' and '^' become '.*', './' and '.^'. Operators already preceded
%     by '.' are left alone, so already-vectorised input such as 'a.*x.^2' is
%     not turned into the invalid 'a..*x..^2'.
%   - Whitespace between a known math function name and its '(' is removed.
math_functions = {'sin','cos','tan','exp','log','log10','sqrt',...
    'asin','acos','atan','sinh','cosh','tanh'};
% Negative lookbehind: only prefix '.' when the operator is not already dotted
formula = regexprep(formula, '(?<!\.)([*/^])', '.$1');
for fn = math_functions
    formula = regexprep(formula, ['\<' fn{1} '\>\s*\('], [fn{1} '(']);
end
end
function [params, guesses] = getParameters(formula)
% Collect the free parameter names of a user-entered model formula.
%
% Identifiers are whole words that start with a letter followed by letters or
% digits. 'x', the constants i/j/pi/inf/nan and the listed math functions are
% excluded. Names are returned in setdiff's sorted order.
% Starting guesses are all 1, except the second parameter, which starts at 0.1.
% If extraction throws, an error dialog is shown and the outputs keep the
% defaults ({} and []).
params = {};
guesses = [];
try
    not_parameters = {'x','i','j','pi','inf','nan',...
        'sin','cos','tan','exp','log','log10','sqrt',...
        'asin','acos','atan','sinh','cosh','tanh'};
    found = regexp(formula, '\<[a-zA-Z][a-zA-Z0-9]*\>', 'match');
    params = setdiff(unique(found), not_parameters);
    count = length(params);
    guesses = ones(1, count);
    if count > 1
        guesses(2) = 0.1;
    end
catch
    errordlg('参数提取失败! 请确认使用标准数学函数(如sin(x), 而非sind(x))');
end
end
function modified_formula = replaceSymbols(formula, params)
% Replace each parameter name in the formula with p(k), where k is the
% name's position in params.
%
% Only whole identifiers are replaced, and all replacements happen in one
% pass. Substrings of other names are therefore untouched (a parameter 'a'
% does not alter 'tan' or 'alpha'), and a parameter named 'p' cannot match
% inside an already inserted 'p(1)'.
[identifiers, gaps] = regexp(formula, '\<[a-zA-Z]\w*\>', 'match', 'split');
for k = 1:numel(identifiers)
    idx = find(strcmp(params, identifiers{k}), 1);
    if ~isempty(idx)
        identifiers{k} = sprintf('p(%d)', idx);
    end
end
% Interleave: gap1 id1 gap2 id2 ... gapN+1 (gaps has one more entry than ids)
pieces = [gaps; [identifiers, {''}]];
modified_formula = [pieces{:}];
end
function optimal_x = refinedQLearningOptimize(fit_result, x_range, model_type)
% Locate the optimum of a fitted 1-D model over the data range widened by
% 10% on each side.
%
% Quadratic models (model_type containing '二次函数') are minimised; all other
% models are maximised.
% Steps:
%   1) fminbnd finds the optimum. If fminbnd throws, a tabular Q-learning
%      search on a 500-point grid is used instead. That search is expensive,
%      so it now runs only on this fallback path, where its result is needed.
%   2) A dense 2000-point grid scan guards against fminbnd stopping at a local
%      optimum. The grid result is adopted, and then polished with fminbnd
%      inside its neighbouring cells, only if it is strictly better (by more
%      than 1e-8). Merely differing values no longer override fminbnd's more
%      precise answer.
% Inputs:
%   fit_result - handle f(x); must accept vector input for the grid scan
%   x_range    - observed x values (only min/max are used)
%   model_type - model label
% Output:
%   optimal_x  - x coordinate of the optimum
range_span = max(x_range) - min(x_range);
x_min = min(x_range) - 0.1 * range_span;
x_max = max(x_range) + 0.1 * range_span;
% Express both cases as a minimisation of `objective`
if contains(model_type, '二次函数')
    objective = fit_result;
else
    objective = @(x) -fit_result(x);
end
search_opts = optimset('Display', 'off', 'TolX', 1e-10);
try
    optimal_x = fminbnd(objective, x_min, x_max, search_opts);
catch
    % Reward is the negated objective, as in the original formulation
    optimal_x = qLearningGridSearch(@(x) -objective(x), x_min, x_max);
    fprintf('使用Q学习结果作为最优解\n');
end
opt_val = fit_result(optimal_x);
fprintf('高精度最优解:\n');
fprintf(' - x* = %.10g\n', optimal_x);
fprintf(' - f(x*) = %.10g\n', opt_val);
% Dense grid scan as a safety net against local optima
test_points = linspace(x_min, x_max, 2000);
grid_values = objective(test_points);
[grid_best, grid_idx] = min(grid_values);
if grid_best < objective(optimal_x) - 1e-8
    candidate = test_points(grid_idx);
    % Polish within the neighbouring grid cells
    lo = test_points(max(1, grid_idx - 1));
    hi = test_points(min(numel(test_points), grid_idx + 1));
    try
        polished = fminbnd(objective, lo, hi, search_opts);
        if objective(polished) < objective(candidate)
            candidate = polished;
        end
    catch
        % keep the raw grid point
    end
    optimal_x = candidate;
    fprintf('调整后最优解: x = %.10g\n', optimal_x);
end
end

function best_x = qLearningGridSearch(reward_func, x_min, x_max)
% Tabular Q-learning over a 500-state uniform grid on [x_min, x_max].
% There are 5 actions, moving -2..+2 grid states. The reward is
% reward_func evaluated at the state reached. Returns the grid point whose
% mean Q-value is highest.
num_states = 500;
states = linspace(x_min, x_max, num_states);
alpha = 0.2;
gamma = 0.95;
epsilon = 0.3;
episodes = 3000;
Q = zeros(num_states, 5);
for ep = 1:episodes
    state = randi(num_states);
    for step = 1:150
        if rand < epsilon
            action = randi(5);
        else
            [~, action] = max(Q(state, :));
        end
        new_state = max(1, min(num_states, state + (action - 3)));
        reward = reward_func(states(new_state));
        Q(state, action) = Q(state, action) + ...
            alpha * (reward + gamma * max(Q(new_state, :)) - Q(state, action));
        state = new_state;
    end
    epsilon = max(0.1, epsilon * 0.999);
end
[~, optimal_state] = max(mean(Q, 2));
best_x = states(optimal_state);
end
function [is_valid, result] = performDimensionalAnalysis(model_type, x_dim, y_dim)
% Dimensional sanity check of the built-in models for the given x/y units.
%
% In both built-in models every coefficient carries its own unit, so the
% models are consistent for any valid x/y pair:
%   quadratic:   [a] = [y]/[x]^2, [b] = [y]/[x], [c] = [y]
%   exponential: [a] = [y], [b] = 1/[x]  (makes the exponent b*x dimensionless)
% The old checks required x^2 to match x, and a dimensionless y for the
% exponential model. Both contradict the derivation above and rejected valid
% data. The hardcoded message 'x[1], y[W]' is replaced by the derived units.
% Outputs:
%   is_valid - false only for unit strings that fail isValidDimension
%   result   - struct with a human-readable .message
result.message = '';
is_valid = true;
if strcmp(x_dim, '1') && strcmp(y_dim, '1')
    result.message = '无量纲数据,跳过量纲检查';
    return;
end
if ~isValidDimension(x_dim) || ~isValidDimension(y_dim)
    result.message = '量纲格式无效,使用基本量纲组合(如m,kg,s)';
    is_valid = false;
    return;
end
switch model_type
    case '二次函数(y=a*x^2+b*x+c)'
        a_dim = divideDimensions(y_dim, multiplyDimensions(x_dim, x_dim));
        b_dim = divideDimensions(y_dim, x_dim);
        result.message = sprintf('二次模型量纲兼容: a[%s], b[%s], c[%s]', a_dim, b_dim, y_dim);
    case '指数函数(y=a*exp(b*x))'
        b_dim = divideDimensions('1', x_dim);
        result.message = sprintf('指数模型量纲有效: a[%s], b[%s]', y_dim, b_dim);
    otherwise
        result.message = '自定义模型量纲未验证';
end
end
function param_dims = deriveParameterDimensions(x_dim, y_dim, formula)
% Derive the unit string of each model parameter.
%   'a*x^2 + b*x + c' : {[y]/[x]^2, [y]/[x], [y]}
%   'a*exp(b*x)'      : {[y], 1/[x]}
%   anything else     : one entry per parameter, each assumed to be [y].
%                       This is a crude default, since no symbolic analysis is
%                       done.
% For generic formulas the parameter list excludes the same reserved names
% as getParameters ('x', constants, math functions). This keeps the count
% and order aligned with the fitted parameters; previously names like 'sin'
% or 'exp' were counted as parameters.
param_dims = {};
if contains(formula, 'a*x^2 + b*x + c')
    param_dims{1} = divideDimensions(y_dim, multiplyDimensions(x_dim, x_dim));
    param_dims{2} = divideDimensions(y_dim, x_dim);
    param_dims{3} = y_dim;
elseif contains(formula, 'a*exp(b*x)')
    param_dims{1} = y_dim;
    param_dims{2} = divideDimensions('1', x_dim);
else
    reserved_words = {'x','i','j','pi','inf','nan',...
        'sin','cos','tan','exp','log','log10','sqrt',...
        'asin','acos','atan','sinh','cosh','tanh'};
    tokens = regexp(formula, '\<[a-zA-Z][a-zA-Z0-9]*\>', 'match');
    params = setdiff(unique(tokens), reserved_words);
    for i = 1:length(params)
        param_dims{i} = y_dim;
    end
end
end
function verifyDimensionalConsistency(params, param_names, param_dims, x_dim, y_dim, formula)
% Print each fitted parameter with its derived unit and flag inconsistencies.
%
% For the exponential model the exponent b*x must be dimensionless, which
% holds exactly when [b] = 1/[x]. The old check required [b] itself to be
% '1', which warned for every dimensional x.
% Inputs:
%   params      - fitted values
%   param_names - parameter names (same order as params)
%   param_dims  - parameter unit strings (same order as params)
%   x_dim       - unit string of x
%   y_dim       - unit string of y
%   formula     - model formula string, used to detect the exponential model
fprintf('参数量纲一致性验证:\n');
is_exp_model = contains(formula, 'exp(b*x)');
for i = 1:length(params)
    param_dim = param_dims{i};
    fprintf(' - %s [%s]: 值=%.8g', param_names{i}, param_dim, params(i));
    if is_exp_model && strcmp(param_names{i}, 'b')
        if ~areDimensionsCompatible(param_dim, divideDimensions('1', x_dim))
            fprintf(' (警告: 指数项应有量纲1)');
        end
    end
    fprintf('\n');
end
if ~strcmp(y_dim, '1')
    fprintf('模型输出量纲应匹配: [%s]\n', y_dim);
end
end
function valid = isValidDimension(dim)
% True when the unit string contains only letters, digits and the
% operators * / ^. An empty string also counts as valid.
first_bad = regexp(dim, '[^a-zA-Z0-9*/^]', 'once');
valid = isempty(first_bad);
end
function result = multiplyDimensions(dim1, dim2)
% Product of two unit strings, e.g. ('m','s') -> 'm*s' and
% ('m/s','kg') -> 'm*kg/s'. The string '1' is the identity.
% Unit strings follow the convention that everything after the first '/' is
% denominator. Numerators and denominators of the operands are therefore
% merged: plain concatenation would turn 'm/s' * 'kg' into 'm/s*kg', which
% reads as m/(s*kg).
if strcmp(dim1, '1'), result = dim2; return; end
if strcmp(dim2, '1'), result = dim1; return; end
parts1 = strsplit(dim1, '/');
parts2 = strsplit(dim2, '/');
numerator = [parts1{1} '*' parts2{1}];
denominator = [parts1(2:end), parts2(2:end)];
if isempty(denominator)
    result = numerator;
else
    result = [numerator '/' strjoin(denominator, '*')];
end
end
function result = divideDimensions(dim1, dim2)
% Quotient of two unit strings, e.g. ('N','m') -> 'N/m' and
% ('N','m/s') -> 'N*s/m'. Dividing by '1' is the identity, and identical
% strings give '1'.
% Unit strings follow the convention that everything after the first '/' is
% denominator. The divisor's denominator therefore moves into the numerator,
% rather than simply appending '/dim2' (which would make 'N/m/s' = N/(m*s)).
if strcmp(dim2, '1'), result = dim1; return; end
if strcmp(dim1, dim2), result = '1'; return; end
parts1 = strsplit(dim1, '/');
parts2 = strsplit(dim2, '/');
numerator = strjoin([parts1(1), parts2(2:end)], '*');
denominator = strjoin([parts1(2:end), parts2(1)], '*');
result = [numerator '/' denominator];
end
function compatible = areDimensionsCompatible(dim1, dim2)
% Two unit strings are compatible when simplifyDimension reduces them to
% the same canonical string.
canonical_1 = simplifyDimension(dim1);
canonical_2 = simplifyDimension(dim2);
compatible = strcmp(canonical_1, canonical_2);
end
function simple_dim = simplifyDimension(dim)
% Reduce a unit string to a canonical form so that equal units compare equal.
%
% Input grammar:
%   - factors are joined by '*'
%   - each factor is a unit name (letter followed by word characters),
%     optionally raised to an integer power with '^', e.g. 'm^2'
%   - the literal '1' is a neutral factor
%   - everything after the first '/' is denominator, i.e. 'a/b*c' and 'a/b/c'
%     both mean a/(b*c)
% Output:
%   - units sorted by name
%   - repeated units collapsed into powers
%   - cancelling units removed
%   - examples: 'kg*m/s^2', 'm^2', '1/s'
%   - '1' when dimensionless
% A string that cannot be parsed is returned unchanged.
% Fixes over the previous version:
%   - 'm*m' was collapsed to 'm' by unique()
%   - char arrays of different length were compared with '<'
%   - full cancellation could yield '/s'
%   - extra '/' groups were ignored
if isempty(dim) || strcmp(dim, '1'), simple_dim = '1'; return; end
try
    units = {};
    powers = [];
    groups = strsplit(dim, '/');
    for g = 1:numel(groups)
        if g == 1, direction = 1; else, direction = -1; end
        factors = strsplit(groups{g}, '*');
        for f = 1:numel(factors)
            factor = strtrim(factors{f});
            if isempty(factor) || strcmp(factor, '1')
                continue;
            end
            pieces = strsplit(factor, '^');
            unit = pieces{1};
            if numel(pieces) > 2 || isempty(regexp(unit, '^[a-zA-Z]\w*$', 'once'))
                error('simplifyDimension:badFactor', 'Cannot parse factor "%s"', factor);
            end
            pow = 1;
            if numel(pieces) == 2
                pow = str2double(pieces{2});
                if ~isfinite(pow) || pow ~= round(pow)
                    error('simplifyDimension:badPower', 'Cannot parse power in "%s"', factor);
                end
            end
            % Accumulate the net exponent of each unit
            k = find(strcmp(units, unit), 1);
            if isempty(k)
                units{end+1} = unit; %#ok<AGROW>
                powers(end+1) = direction * pow; %#ok<AGROW>
            else
                powers(k) = powers(k) + direction * pow;
            end
        end
    end
    % Drop units that cancelled out
    keep = powers ~= 0;
    units = units(keep);
    powers = powers(keep);
    if isempty(units)
        simple_dim = '1';
        return;
    end
    [units, order] = sort(units);
    powers = powers(order);
    num_terms = {};
    den_terms = {};
    for k = 1:numel(units)
        magnitude = abs(powers(k));
        if magnitude == 1
            term = units{k};
        else
            term = sprintf('%s^%d', units{k}, magnitude);
        end
        if powers(k) > 0
            num_terms{end+1} = term; %#ok<AGROW>
        else
            den_terms{end+1} = term; %#ok<AGROW>
        end
    end
    if isempty(num_terms)
        num_str = '1';
    else
        num_str = strjoin(num_terms, '*');
    end
    if isempty(den_terms)
        simple_dim = num_str;
    else
        simple_dim = [num_str '/' strjoin(den_terms, '*')];
    end
catch
    simple_dim = dim;
end
end