主代码

% function main()
clc
clear
close all
%% 1.读取数据
%三训一,三个行向量对一个行向量
%训练集
TR1=[0,0,4,4,1,0,1,4,0,1,0,0,0,0,0,0,0,1,0,0;0,0,990,1081,184,0,486,795,0,223,0,0,0,0,0,0,0,198,0,0;0,0,403,363,184,0,486,215,0,223,0,0,0,0,0,0,0,198,0,0];
TAG1=[1,1,1,1,1,1,3,3,3,3,2,2,2,2,2,2,4,4,4,4];
%验证集
TE2=[1,9,6,1,0,5,3,3,6,0,8,0,0,2,0,0,2,3,0,4;345,3256,1757,237,0,1183,785,1617,2520,0,4820,0,0,376,0,0,590,1428,0,1082;345,838,403,237,0,329,319,590,594,0,1241,0,0,196,0,0,338,761,0,414];
TAG2=[1,3,3,1,1,2,2,2,3,1,3,1,1,2,1,2,2,2,1,1];

RFRFRF(TR1,TAG1,TE2,TAG2)

RFRFRF.m

function RFRFRF(TR1,TAG1,TE2,TAG2)
%% 2.数据操作
%%  2.2.划分训练集和测试集
PN = TR1;%训练集输入
TN = TAG1;%训练集输出
PM = TE2;%测试集输入
TM = TAG2;%测试集输出
%%  2.3.数据归一化
[pn, ps_input] = mapminmax(PN, 0, 1);%归一化到(0,1)
pn=pn';
pm = mapminmax('apply', PM, ps_input);%引用结构体,保持归一化方法一致;
pm=pm';
[tn, ps_output] = mapminmax(TN, 0, 1);
tn=tn';

%%  3.模型参数设置及训练模型
trees = 100; % 决策树数目
leaf  = 20; % 最小叶子数
OOBPrediction = 'on';  % 打开误差图
OOBPredictorImportance = 'on'; % 计算特征重要性
Method = 'regression';  % 选择回归或分类
net = TreeBagger(trees, pn, tn, 'OOBPredictorImportance', OOBPredictorImportance,...
      'Method', Method, 'OOBPrediction', OOBPrediction, 'minleaf', leaf);
% importance = net.OOBPermutedPredictorDeltaError;  % 重要性

%%  4.仿真测试
pyuce = predict(net, pm );
Pyuce = mapminmax('reverse', pyuce, ps_output);%数据反归一化
Pyuce =Pyuce';
% result_pre=round(Pyuce);%预测结果
result_pre=Pyuce;%预测结果
correct1=length(find((result_pre()-ones())==0))/100;
disp(['测试集的正确率为:',  num2str(correct1)])

%%  绘图
figure() %画图真实值与预测值对比图
plot(TM,'bo-','markersize',10,'LineWidth',1)
hold on
plot(result_pre,'r*-','markersize',10,'LineWidth',1)
hold on
legend('真实值','预测值')
xlabel('预测样本')
ylabel('预测结果')

%%  相关指标计算
Pyuce=round(Pyuce);%预测结果
error=Pyuce-TM;
[~,len]=size(TM);
R2=1-sum((TM-Pyuce).^2)/sum((mean(TM)-TM).^2);%相关性系数
MSE=error*error'/len;%均方误差
RMSE=MSE^(1/2);%均方根误差
disp(['测试集数据的MSE为:', num2str(MSE)])
disp(['测试集数据的MBE为:', num2str(RMSE)])
disp(['测试集数据的R2为:', num2str(R2)])

zql=length(find((result_pre-TM)==0))/300;
disp(['测试集正确率为:', num2str(zql)])

04-06 10:52