山東大學機器學習(實驗五解讀)——SVM
1.SVM
(1)這里我選擇懲罰系數做實驗,不同的懲罰系數可能導致結果不同。
- 核函數kernel.m
function K = kernel(X,Y,type,gamma)
switch type
case 'linear' %線性核
K = X*Y';
case 'rbf' %高斯核
m = size(X,1);
K = zeros(m,m);
for i = 1:m
for j = 1:m
K(i,j) = exp(-gamma*norm(X(i,:)-Y(j,:))^2);
end
end
end
end
- 訓練函數svmTrain.m
function svm = svmTrain(X,Y,kertype,gamma,C)
%二次規劃問題,使用quadprog,詳細help quadprog
n = length(Y);
H = (Y*Y').*kernel(X,X,kertype,gamma);
f = -ones(n,1);
A = [];
b = [];
Aeq = Y';
beq = 0;
lb = zeros(n,1);
ub = C*ones(n,1);
a = quadprog(H,f,A,b,Aeq,beq,lb,ub);
epsilon = 3e-5; %閾值可以根據自身需求選擇
%找出支持向量
svm_index = find(abs(a)> epsilon);
svm.sva = a(svm_index);
svm.Xsv = X(svm_index,:);
svm.Ysv = Y(svm_index);
svm.svnum = length(svm_index);
svm.a = a;
end
- 預測函數predict1.m
function test = predict1(train_data_name,test_data_name,kertype,gamma,C)
%(1)-------------------training data ready-------------------
train_data = load(train_data_name);
n = size(train_data,2); %data column
train_x = train_data(:,1:n-1);
train_y = train_data(:,n);
%find the position of positive label and negtive label
pos = find ( train_y == 1 );
neg = find ( train_y == -1 );
figure('Position',[400 400 1000 400]);
subplot(1,2,1);
plot(train_x(pos,1),train_x(pos,2),'k+');
hold on;
plot(train_x(neg,1),train_x(neg,2),'bs');
hold on;
%(2)-----------------decision boundary-------------------
train_svm = svmTrain(train_x,train_y,kertype,gamma,C);
%plot the support vector
plot(train_svm.Xsv(:,1),train_svm.Xsv(:,2),'ro');
train_a = train_svm.a;
train_w = [sum(train_a.*train_y.*train_x(:,1));sum(train_a.*train_y.*train_x(:,2))];
train_b = sum(train_svm.Ysv-train_svm.Xsv*train_w)/size(train_svm.Xsv,1);
train_x_axis = 0:1:200;
plot(train_x_axis,-train_b-train_w(1,1)*train_x_axis/train_w(2,1),'-');
legend('1','-1','suport vector','decision boundary');
title('training data')
hold on;
%(3)-------------------testing data ready----------------------
test_data = load(test_data_name);
m = size(test_data,2); %data column
test_x = test_data(:,1:m-1);
test_y = test_data(:,m);
%find the test data positive label and negtive label
test_label = sign(test_x*train_w + train_b);
subplot(1,2,2);
test_pos = find ( test_y == 1 );
test_neg = find ( test_y == -1 );
plot(test_x(test_pos,1),test_x(test_pos,2),'k+');
hold on;
plot(test_x(test_neg,1),test_x(test_neg,2),'bs');
hold on;
%(4)------------------decision boundary -----------------------
test_x_axis = 0:1:200;
plot(test_x_axis,-train_b-train_w(1,1)*test_x_axis/train_w(2,1),'-');
legend('1','-1','decision boundary');
title('testing data');
%print the detail
fprintf('--------------------------------------------\n');
fprintf('training_data: %s\n',train_data_name);
fprintf('testing_data: %s\n',test_data_name);
fprintf('C = %d\n',C);
fprintf('number of test data label: %d\n',size(test_data,1));
fprintf('predict corret number of test data label: %d\n',length(find(test_label==test_y)));
fprintf('Success rate: %.4f\n',length(find(test_label==test_y))/size(test_data,1));
fprintf('--------------------------------------------\n');
end
- 主函數part1.m
kertype = 'linear';
gamma = 0; C = 1;
predict1('training_1.txt','test_1.txt',kertype,gamma,C);
predict1('training_2.txt','test_2.txt',kertype,gamma,C);
做出圖像如下
- training_1.txt 和 test_1.txt
- training_2.txt 和 test_2.txt
(2)測試數據的預測結果如下,可以由上面第(1)問中的代碼獲得
training_data: training_1.txt
testing_data: test_1.txt
C = 1
number of test data label: 500
predict corret number of test data label: 500
Success rate: 1.0000
training_data: training_2.txt
testing_data: test_2.txt
C = 1
number of test data label: 500
predict corret number of test data label: 500
Success rate: 1.0000
可以發現測試數據的正確率是百分之百,說明決策邊界剛好可以把正負樣本給完全分隔開。
(3)更改第(1)問中的主函數,其余的代碼不變
- part1_3.m
kertype = 'linear';
gamma = 0;
C = [0.01,0.1,1,10,100];
for i=1:size(C,2)
predict1('training_1.txt','test_1.txt',kertype,gamma,C(i));
end
for i=1:size(C,2)
predict1('training_2.txt','test_2.txt',kertype,gamma,C(i));
end
觀察結果
- training_1.txt 和 test_1.txt
C | success rate |
---|---|
0.01 | 1.0000 |
0.1 | 1.0000 |
1 | 1.0000 |
10 | 1.0000 |
100 | 1.0000 |
- training_2.txt 和 test_2.txt
C | success rate |
---|---|
0.01 | 1.0000 |
0.1 | 1.0000 |
1 | 1.0000 |
10 | 1.0000 |
100 | 1.0000 |
從上面的結果看似乎沒差別,這是因為我們的數據集太好了,所以預測正確率都是1。正常的情況下,C比較大的時候表示我們想要犯更少的錯誤,但margin會稍微小一點;C比較的小的時候表示我們想要更大的margin,劃分錯誤多一點沒關系。也就是說C比較大時正確率會高一些,而C比較小正確率會低一些。
2. 手寫字體識別
實驗給的strimage.m的作用是從train-01-images.svm中選定指定行以圖片的形式顯示0或1,其中圖片的像素是28×28。
(1)以下兩個m文件處理train-01-images.svm和test-01-images.svm,其功能是參照strimage.m將數據集每一行先轉換為28×28像素的圖像,圖像的每一點有確定的值。然后將所有的點展開成一行全部存儲于hand_digists_train.dat和hand_digists_test.dat。轉換后的hand_digists_train.dat大小為12665×785,hand_digists_test.dat的大小為2115×785,每一行中的最后一個數字對應著標簽,即第785列。
- re_hand_digits.m
function svm = re_hand_digits(filename,n)
fidin = fopen(filename);
i = 1;
apres = [];
while ~feof(fidin)
tline = fgetl(fidin); % 從文件讀行
apres{i} = tline;
i = i+1;
end
grid = zeros(n,784);
label = zeros(n,1);
for k = 1:n
a = char(apres(k));
lena = size(a,2);
xy = sscanf(a(4:lena), '%d:%d');
label(k,1) = sscanf(a(1:3),'%d');
lenxy = size(xy,1);
for i=2:2:lenxy %% 隔一個數
if(xy(i)<=0)
break
end
grid(k,xy(i-1)) = xy(i) * 100/255; %轉為有顏色的圖像
end
end
svm.grid = grid;
svm.label = label;
end
- save_data.m
svm1 = re_hand_digits('train-01-images.svm',12665);
svm2 = re_hand_digits('test-01-images.svm',2115);
train_x = svm1.grid; train_y = svm1.label;
test_x = svm2.grid; test_y = svm2.label;
train = [train_x,train_y];
test = [test_x,test_y];
[row,col] = size(test);
fid=fopen('hand_digits_test.dat','wt');
for i=1:1:row
for j=1:1:col
if(j==col)
fprintf(fid,'%g\n',test(i,j));
else
fprintf(fid,'%g\t',test(i,j));
end
end
end
fclose(fid);
[row,col] = size(train);
fid=fopen('hand_digits_train.dat','wt');
for i=1:1:row
for j=1:1:col
if(j==col)
fprintf(fid,'%g\n',train(i,j));
else
fprintf(fid,'%g\t',train(i,j));
end
end
end
fclose(fid);
實驗雖然要求使用全部的訓練集進行訓練,但你會發現使用全部訓練集的時候,使用quadprog函數會出現內存不足的情況,當然好配置的電腦可以跑完,但我們班的大多數人都是跑內存爆炸的,所以這里我只采用大小為3000的訓練集,測試集數量不變,你可以根據自身需求選擇,當然這對實驗的結果是沒多大影響的。
- strimage.m
實驗雖然給了strimage.m,但我們還是有必要更改一下這個代碼,因為這個代碼只能查看訓練集的0或1的圖像,也就是train-01-images.svm中的圖像,我們的目的是它也能查看test-01-images.svm的0或1圖像,這個代碼其實就加了一個參數filename而已。
function strimage(filename,n)
fidin = fopen(filename);
i = 1;
apres = [];
while ~feof(fidin)
tline = fgetl(fidin); % 從文件讀行
apres{i} = tline;
i = i+1;
end
%選中我們選定的第n張圖片
a = char(apres(n));
lena = size(a);
lena = lena(2);
%xy存儲像素的索引和相應的灰度值
xy = sscanf(a(4:lena), '%d:%d');
lenxy = size(xy);
lenxy = lenxy(1);
grid = [];
grid(784) = 0; %28*28網格,0代表黑色背景
for i=2:2:lenxy %% 隔一個數
if(xy(i)<=0)
break
end
grid(xy(i-1)) = xy(i) * 100/255; %轉為有顏色的圖像
end
%顯示手寫數字圖像
grid1 = reshape(grid,28,28);
grid1 = fliplr(diag(ones(28,1)))*grid1;
grid1 = rot90(grid1,3);
image(grid1);
hold on;
end
- 預測函數predict2.m
function [test_miss,train_miss] = predict2(train_data_name,test_data_name,kertype,gamma,C)
%(1)-------------------training data ready-------------------
train_data = train_data_name;
n = size(train_data,2);
train_x = train_data(:,1:n-1);
train_y = train_data(:,n);
%(2)-----------------training model-------------------
%二次規劃用來求解問題,使用quadprog
n = length(train_y);
H = (train_y*train_y').*kernel(train_x,train_x,kertype,gamma);
f = -ones(n,1); %f'為1*n個-1
A = [];
b = [];
Aeq = train_y';
beq = 0;
lb = zeros(n,1);
if C == 0 %無正則項
ub = [];
else %有正則項
ub = C.*ones(n,1);
end
train_a = quadprog(H,f,A,b,Aeq,beq,lb,ub);
epsilon = 2e-7;
%找出支持向量
sv_index = find(abs(train_a)> epsilon);
Xsv = train_x(sv_index,:);
Ysv = train_y(sv_index);
svnum = length(sv_index);
train_w(1:784,1) = sum(train_a.*train_y.*train_x(:,1:784));
train_b = sum(Ysv-Xsv*train_w)/svnum;
train_label = sign(train_x*train_w+train_b);
train_miss = find(train_label~=train_y);
%(3)-------------------testing data ready----------------------
test_data = test_data_name;
m = size(test_data,2);
test_x = test_data(:,1:m-1);
test_y = test_data(:,m);
test_label = sign(test_x*train_w+train_b);
test_miss = find(test_label~=test_y);
%(4)------------------detail -----------------------;
%print the detail
fprintf('--------------------------------------------\n');
fprintf('C = %d\n',C);
fprintf('number of test data label: %d\n',size(test_data,1));
fprintf('number of train data label: %d\n',size(train_data,1));
fprintf('predict corret number of test data label: %d\n',length(find(test_label==test_y)));
fprintf('predict corret number of train data label: %d\n',length(find(train_label==train_y)));
fprintf('Success rate of test data: %.4f\n',length(find(test_label==test_y))/size(test_data,1));
fprintf('Success rate of train data: %.4f\n',length(find(train_label==train_y))/size(train_data,1));
fprintf('--------------------------------------------\n');
end
- 主函數part2.m
train_data = load('hand_digits_train.dat');
test_data = load('hand_digits_test.dat');
train_len = size(train_data,1); test_len = size(test_data,1);
train_index = randperm(train_len,3000);
test_index = randperm(test_len,2115);
train_select = zeros(3000,785);
test_select = zeros(2115,785);
for i = 1:3000
train_select(i,:) = train_data(train_index(i),:);
end
for i = 1:2115
test_select(i,:) = test_data(test_index(i),:);
end
%無正則項
[train_miss_index,test_miss_index] = predict2(train_select,test_select,'linear',0,0);
%查看被錯誤分類的手寫字體
%訓練錯誤手寫字體
for i = 1:length(train_miss_index)
strimage('train-01-images.svm',i);
figure;
end
%測試錯誤手寫字體
for i = 1:length(test_miss_index)
strimage('test-01-images.svm',i);
if i~=length(test_miss_index)
figure;
end
end
- 打印訓練錯誤和測試錯誤
C = 0
number of test data label: 2115
number of train data label: 3000
predict corret number of test data label: 2108
predict corret number of train data label: 2991
Success rate of test data: 0.9967
Success rate of train data: 0.9970
- 附上一張訓練錯誤圖像和測試錯誤圖像(左邊是訓練錯誤圖像,右邊是測試錯誤圖像)
可以發現,錯誤的和書寫不規范有很大的關系。
(2)這里主要改的代碼是主函數part2.m,其余的不用變。
train_data = load('hand_digits_train.dat');
test_data = load('hand_digits_test.dat');
train_len = size(train_data,1); test_len = size(test_data,1);
train_index = randperm(train_len,3000);
test_index = randperm(test_len,2115);
train_select = zeros(3000,785);
test_select = zeros(2115,785);
for i = 1:3000
train_select(i,:) = train_data(train_index(i),:);
end
for i = 1:2115
test_select(i,:) = test_data(test_index(i),:);
end
C = [0.01,0.1,1,10,100];
%有正則項
for i=1:size(C,2)
predict2(train_select,test_select,'linear',0,C(i));
end
觀察結果如下
- 訓練誤差
C | success rate |
---|---|
0.01 | 0.9963 |
0.1 | 0.9973 |
1 | 0.9947 |
10 | 0.9993 |
100 | 0.9920 |
- 測試誤差
C | success rate |
---|---|
0.01 | 0.9939 |
0.1 | 0.9957 |
1 | 0.9910 |
10 | 0.9991 |
100 | 0.9960 |
回答下列問題
(i)從上述結果發現C=100時訓練誤差最大。
(ii)問題(i)中對應的測試誤差為1-0.9960 = 0.0040,在matlab顯示2115個測試數據錯了17個。
(iii)由表中可以看C在10時測試誤差和訓練誤差都很小,所以我推測C=10左右會使得測試誤差很小,當然只是我個人推測。
3. 非線性SVM
- 預測函數predict3.m
function test = predict3(train_name,type,gamma,C)
%-----------------------training data ready------------------------
train = train_name;
[m,n] = size(train);
train_x = train(:,1:n-1);
train_y = train(:,n);
pos = find(train_y == 1);
neg = find(train_y == -1);
plot(train_x(pos,1),train_x(pos,2),'k+');
hold on;
plot(train_x(neg,1),train_x(neg,2),'bs');
hold on;
%-----------------------training model--------------------------
%二次規劃用來求解問題,使用quadprog
K = kernel(train_x,train_x,type,gamma);
H = (train_y*train_y').*K;
f = -ones(m,1);
A = [];
b = [];
Aeq = train_y';
beq = 0;
lb = zeros(m,1);
if C == 0
ub = [];
else
ub = C*ones(m,1);
end
a = quadprog(H,f,A,b,Aeq,beq,lb,ub);
epsilon = 1e-5;
%查找支持向量
sv_index = find(abs(a)> epsilon);
Xsv = train_x(sv_index,:);
Ysv = train_y(sv_index);
svnum = length(sv_index);
%make classfication predictions over a grid of values
xplot = linspace(min(train_x(:,1)),max(train_x(:,1)),100)';
yplot = linspace(min(train_x(:,2)),max(train_x(:,2)),100)';
[X,Y] = meshgrid(xplot,yplot);
vals = zeros(size(X));
%calculate decision value
train_a = a;
sum_b = 0;
for k = 1:svnum
sum = 0;
for i = 1:m
sum = sum + train_a(i,1)*train_y(i,1)*K(i,k);
end
sum_b = sum_b + Ysv(k) - sum;
end
train_b = sum_b/svnum;
for i = 1:100
for j = 1:100
x_y = [X(i,j),Y(i,j)];
sum = 0;
for k = 1:m
sum = sum + train_a(k,1)*train_y(k,1)*exp(-gamma*norm(train_x(k,:)-x_y)^2);
end
vals(i,j) = sum + train_b;
end
end
%plot the SVM boundary
colormap bone;
contour(X,Y,vals,[0 0],'LineWidth',2);
title(['\gamma = ',num2str(gamma)]);
end
- 主函數part3.m
type = 'rbf';
train_name = load('training_3.text');
gamma = [1,10,100,1000];
C = 1;
for i = 1:length(gamma)
predict3(train_name,type,gamma(i),C);
if i ~= length(gamma)
figure;
end
end
- 實驗結果
可以發現γ越大,會出現過擬合現象,而γ越小,會出現欠擬合現象。
智能推薦
3D游戲編程與設計——游戲對象與圖形基礎章節作業與練習
3D游戲編程與設計——游戲對象與圖形基礎章節作業與練習 3D游戲編程與設計——游戲對象與圖形基礎章節作業與練習 自學資源 作業內容 1、基本操作演練【建議做】 天空盒的制作: 地圖的制作: 整體效果: 2、編程實踐 項目要求: 項目結構: 代碼詳解: Actions: ISSActionCallback.cs SSAction.cs SSAction...
FlycoTabLayout 的使用
FlycoTabLayout 一個Android TabLayout庫,目前有3個TabLayout SlidingTabLayout:參照PagerSlidingTabStrip進行大量修改. 新增部分屬性 新增支持多種Indicator顯示器 新增支持未讀消息顯示 新增方法for懶癌患者 CommonTabLayout:不同于SlidingTabLayout對ViewPager依賴,它是一個不...
爬蟲項目實戰八:爬取天氣情況
爬取天氣情況 目標 項目準備 接口分析 代碼實現 效果顯示 寫入本地 目標 根據天氣接口,爬取接下來一周的天氣情況。 項目準備 軟件:Pycharm 第三方庫:requests,BeautifulSoup,csv 接口地址:http://api.k780.com:88/?app=weather.future&weaid=城市名&appkey=10003&sign=b59bc...
關于web項目的目錄問題
先給段代碼: 上面這個代碼一直出錯,我不知道原因,后面不停的查找資料發現了問題:我的web項目輸出目錄有問題,因為我也是第一次用idea寫web項目,發現很多bug 其實都沒有太大問題,我們需要注意的是你必須在out這個輸出文件夾中擁有這個文件,out輸出文件夾會默認過濾這些文件...
二叉搜索樹轉化為雙向鏈表
題目描述: 輸入一棵二叉搜索樹,將該二叉搜索樹轉換成一個排序的循環雙向鏈表。要求不能創建任何新的節點,只能調整樹中節點指針的指向。 為了讓您更好地理解問題,以下面的二叉搜索樹為例: 我們希望將這個二叉搜索樹轉化為雙向循環鏈表。鏈表中的每個節點都有一個前驅和后繼指針。對于雙向循環鏈表,第一個節點的前驅是最后一個節點,最后一個節點的后繼是第一個節點。 下圖展示了上面的二叉搜索樹轉化成的鏈表。&ldqu...
猜你喜歡
Cocos2d-x 2.0 網格動畫深入分析
[Cocos2d-x相關教程來源于紅孩兒的游戲編程之路CSDN博客地址:http://blog.csdn.net/honghaier] 紅孩兒Cocos2d-X學習園地QQ2群:44208467加群寫:Cocos2d-x 紅孩兒Cocos2d-X學習園地QQ群:249941957[暫滿]加群寫:Cocos2d-x 本章為我的Cocos2d-x教程一書初稿。望各位看官多提建議! Cocos2d-x ...
解決Python數據可視化中文部分顯示方塊問題
一、問題 代碼如下,發現標題的中文顯示的是方塊 如下圖 二、解決方法 一般數據可視化使用matplotlib庫,設置中文字體可以在導入之后添加兩句話(這里的SimHei指的是黑體,KaiTi指的是楷體) 三、效果 1.黑體: 2.楷體: 具體的其他字體可以在matplotlib\mpl-data\fonts\ttf找到~ 四、Windows的常用字體 黑體、楷體、仿宋是可以用的,其他的字體可能需要...
Linux的LVM掛載(Centos)
LVM掛載 1、虛擬機添加新增磁盤(如已添加可略過) 2、查看是否有新的硬盤 3、對磁盤分區 4、LVM磁盤創建 參考地址: https://blog.51cto.com/11555417/2158443 1、虛擬機添加新增磁盤(如已添加可略過) 1.點擊虛擬機,選擇硬盤,點擊添加,選擇SCSI硬盤,添加硬盤(如下圖所示)。 2、查看是否有新的硬盤 可以看到 /dev/sdb 是我們新建的磁盤5G...
Java四大元注解介紹
Java四大元注解介紹 什么是元注解? 元注解就是注解到注解上的注解。它們被用來提供對其它 annotation類型作說明。 Java5.0定義的元注解: @Retention、@Documented、@Target、@Inherited,這些類型和它們所支持的類在java.lang.annotation包中可以找到。如圖所示: 接下來我們看一下每個元注解的作用和相應分參數的使用說明。 @Docu...