Implementing K-Nearest Neighbors (K-NN) on the Titanic Data
1 K-NN
The core idea of kNN is simple: if most of a sample's k nearest neighbors in feature space belong to one class, the sample is assigned to that class and shares the characteristics of samples in that class. The classification decision therefore depends only on the classes of the nearest one or few samples.
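To make the decision rule concrete, here is a minimal R sketch of the majority vote among the k nearest neighbors, assuming Euclidean distance (illustrative only; the analysis below uses knn3 from caret):
## Minimal kNN decision rule: Euclidean distance + majority vote
## (illustrative sketch; the real analysis below uses caret::knn3)
knn_predict_one <- function(train_x, train_y, new_x, k = 5) {
  # distance from new_x to every training point
  d <- sqrt(rowSums(sweep(as.matrix(train_x), 2, new_x)^2))
  nn <- order(d)[1:k]          # indices of the k nearest neighbors
  votes <- table(train_y[nn])  # count the classes among those neighbors
  names(which.max(votes))      # predicted class = the majority vote
}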
1.1 Advantages
- Simple, easy to understand and implement; no parameters to estimate and no training phase;
- Well suited to classifying rare events;
- Particularly suitable for multi-class problems (multi-modal, where objects carry multiple class labels), where kNN can outperform SVM.
1.2 Disadvantages
- It is a lazy algorithm: classifying a test sample requires heavy computation and memory, so scoring is slow;
- When the classes are imbalanced, e.g. one class has many samples while the others have few, the k neighbors of a new sample may be dominated by the large class (one common remedy, distance-weighted voting, is sketched after this list);
- Poor interpretability: it cannot produce explicit rules the way a decision tree can.
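The sketch below illustrates distance-weighted voting, a standard remedy for the imbalance problem that is not used in this post: each neighbor votes with weight 1/distance, so a few close minority-class points can outvote many distant majority-class points (weighted_knn_vote is a hypothetical helper, not part of the analysis below):
## Distance-weighted voting (sketch of a standard remedy, not used below)
weighted_knn_vote <- function(train_x, train_y, new_x, k = 5) {
  d <- sqrt(rowSums(sweep(as.matrix(train_x), 2, new_x)^2))
  nn <- order(d)[1:k]
  w <- 1 / (d[nn] + 1e-8)  # inverse-distance weights; small constant avoids division by zero
  names(which.max(tapply(w, train_y[nn], sum)))  # class with the largest total weight wins
}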
2 Exploring the data
2.1 Data cleaning
The data come from https://www.kaggle.com/c/titanic/data and contain information about the passengers on the Titanic. Our goal is to predict whether a passenger survived from the other passenger attributes.
library(dplyr) # data processing
library(ggplot2) # visualization
library(rpart) # decision tree
library(caret) # knn estimation
train = read.csv("./Titanic/train.csv",na.strings = "")
test = read.csv("./Titanic/test.csv",na.strings = "")
gender = read.csv("./Titanic/gendermodel.csv",na.strings = "")
all = bind_rows(train, test)
Check the class of each variable in the data set:
sapply(all,class)
## PassengerId Survived Pclass Name Sex Age
## "integer" "integer" "integer" "character" "factor" "numeric"
## SibSp Parch Ticket Fare Cabin Embarked
## "integer" "integer" "character" "numeric" "character" "factor"
and how many missing values each variable contains:
sapply(all,FUN = function(x) sum(is.na(x)))
## PassengerId Survived Pclass Name Sex Age
## 0 418 0 0 0 263
## SibSp Parch Ticket Fare Cabin Embarked
## 0 0 0 1 1014 2
PassengerId is just an index for each observation and carries no real information; we drop it together with the Name, Ticket, and Cabin variables.
all = all[,!names(all) %in% c("PassengerId","Name","Ticket","Cabin")]
The Embarked values of observations 62 and 830 are missing; inspect these two samples:
which(is.na(all$Embarked))
## [1] 62 830
all[c(62,830),]
## Source: local data frame [2 x 8]
##
## Survived Pclass Sex Age SibSp Parch Fare Embarked
## 1 1 1 female 38 0 0 80 NA
## 2 1 1 female 62 0 0 80 NA
ggplot(data = all[-c(62,830),], aes(x = Embarked, y = Fare)) +
  geom_boxplot(aes(fill = factor(Pclass))) +
  geom_hline(aes(yintercept = 80), col = "red", linetype = 'dashed', lwd = 2)
## Warning: Removed 1 rows containing non-finite values (stat_boxplot).
In the plot, the fare of 80 paid by these two first-class passengers matches the median first-class fare for embarkation port 'C', so we fill in 'C' for observations 62 and 830:
all[c(62,830),]$Embarked = "C"
Inspect the observation with a missing Fare value and impute it from the related information:
which(is.na(all$Fare))
## [1] 1044
all[1044,]
## Source: local data frame [1 x 8]
##
## Survived Pclass Sex Age SibSp Parch Fare Embarked
## 1 NA 3 male 60.5 0 0 NA S
ggplot(data = all[all$Pclass == '3' & all$Embarked == 'S',], aes(x = Fare)) +
  geom_density(fill = "grey") +
  geom_vline(aes(xintercept = median(Fare, na.rm = T)),
             col = "red", linetype = 'dashed', lwd = 1)
## Warning: Removed 1 rows containing non-finite values (stat_density).
all[1044,]$Fare = median(all[all$Pclass == '3' & all$Embarked == 'S',]$Fare, na.rm = T)
The Age variable contains many missing values, so we use a decision tree to predict Age from the other variables:
factor_vars <- c('Pclass','Sex','Embarked')
all[factor_vars] = lapply(all[factor_vars],function(x) as.factor(x))
rpart_model = rpart(Age~Pclass+Sex+SibSp+Parch+Fare+Embarked,data = all[!is.na(all$Age),])
Age_predict <- predict(rpart_model,all[is.na(all$Age),c("Pclass","Sex","SibSp","Parch","Fare","Embarked")],type = "vector")
# Combine the observed ages with the rpart predictions for plotting
Age = c(all$Age[!is.na(all$Age)], Age_predict)
# Plot age distributions
par(mfrow=c(1,2))
hist(all$Age, freq=F, main='Age: Original Data',
col='darkgreen', ylim=c(0,0.04))
hist(Age, freq=F, main='Age: rpart Output',
col='lightgreen', ylim=c(0,0.06))
# Replace the original missing values with the ages predicted by rpart
all$Age[is.na(all$Age)] = Age_predict
sapply(all,FUN = function(x) sum(is.na(x)))
## Survived Pclass Sex Age SibSp Parch Fare Embarked
## 418 0 0 0 0 0 0 0
Wow, no missing values at last. Now we get to decide who lives and who dies!
2.2 Converting factors into dummy variables
Factor variables have to be converted into dummy variables before applying KNN, since the method relies on numeric distances:
### convert a factor into dummy variables via model.matrix
change_factor <- function(x){
  if(is.factor(x)){
    data <- data.frame(x = x)
    # drop the intercept column so only the dummy columns remain
    output <- model.matrix(~x, data = data)[,-1]
  }else{
    output <- x
  }
  output
}
### change factor into dummy variable by class.ind in nnet package
# library("nnet")
# dummy_Pclass <- class.ind(train$Pclass)
all_1 <- sapply(all[,-1], FUN = change_factor)  # a list: Pclass and Embarked become dummy matrices
all_1 <- do.call(cbind, all_1)                  # bind everything into one numeric matrix
all_1 <- apply(all_1, 2, scale)                 # standardize each column to mean 0, sd 1
all <- data.frame(Survived = all[,1], as.data.frame(all_1))
head(all)
## Survived x2 x3 Sex Age SibSp
## 1 0 -0.5178859 0.9195737 0.7432129 -0.5764224 0.4811039
## 2 1 -0.5178859 -1.0866296 -1.3444816 0.6330570 0.4811039
## 3 1 -0.5178859 0.9195737 -1.3444816 -0.2740525 -0.4789037
## 4 1 -0.5178859 -1.0866296 -1.3444816 0.4062796 0.4811039
## 5 0 -0.5178859 0.9195737 0.7432129 0.4062796 -0.4789037
## 6 0 -0.5178859 0.9195737 0.7432129 -0.1658179 -0.4789037
## Parch Fare xQ xS
## 1 -0.4448295 -0.5030988 -0.3219173 0.6571424
## 2 -0.4448295 0.7344629 -0.3219173 -1.5205776
## 3 -0.4448295 -0.4900532 -0.3219173 0.6571424
## 4 -0.4448295 0.3830371 -0.3219173 0.6571424
## 5 -0.4448295 -0.4876373 -0.3219173 0.6571424
## 6 -0.4448295 -0.4797462 3.1040152 -1.5205776
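In the output, x2 and x3 are the Pclass dummies and xQ and xS the Embarked dummies (the names come from model.matrix inside change_factor). Standardization matters because KNN is distance-based: without it, a wide-range variable such as Fare would dominate the distance. A quick sanity check (an addition, not part of the original analysis) that each predictor now has mean 0 and standard deviation 1:
## Sanity check: after scaling, every predictor should have mean ~0 and sd 1
round(colMeans(all[,-1]), 10)
apply(all[,-1], 2, sd)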
3 KNN method for classification
After scaling and converting the factors into dummy variables, we fit a KNN model with k = 5 and check the prediction accuracy:
## knn for k = 5
train_knn <- all[1:nrow(train),]
test_knn <- all[-c(1:nrow(train)),]
train_knn$Survived <- factor(train_knn$Survived)  # knn3 expects a factor outcome for classification
knn_model <- knn3(Survived ~ ., data = train_knn, k = 5)
knn_predict <- predict(knn_model,test_knn,type="prob")
knn_predict <- ifelse(knn_predict[,1] > 0.5,0,1)
table(knn_predict,gender[,2])
##
## knn_predict 0 1
## 0 230 25
## 1 36 127
mean(knn_predict == gender[,2])
## [1] 0.854067
3.1 Cross-validation to choose k
cv.knn <- function(data, n = 5, k = 5){
  # randomly assign each row to one of n folds
  index = sample(1:n, replace = T, size = nrow(data))
  a.e = numeric(n)
  for(i in 1:n){
    train = data[index != i,]  # fit on the other n-1 folds
    test = data[index == i,]   # evaluate on fold i
    knn_model = knn3(Survived ~ ., data = train, k = k)
    knn_predict = predict(knn_model, test, type = "prob")
    knn_predict = ifelse(knn_predict[,1] > 0.5, 0, 1)
    a.e[i] = mean(knn_predict == test$Survived)
  }
  mean(a.e)
}
cv.knn(train_knn)
## [1] 0.7797442
k = 2:20
set.seed(1234)
a.e = sapply(k,function(x) cv.knn(train_knn, n = 5,k = x))
plot(k,a.e,type = "b")
k.select = k[which.max(a.e)]  # map the best position back to the actual value of k
knn_model <- knn3(Survived~.,data = train_knn, k=k.select)
knn_predict <- predict(knn_model,test_knn,type="prob")
knn_predict <- ifelse(knn_predict[,1] > 0.5,0,1)
table(knn_predict,gender[,2])
##
## knn_predict 0 1
## 0 228 20
## 1 38 132
mean(knn_predict == gender[,2])
## [1] 0.861244
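As a cross-check (an alternative not used in the original post), caret can tune k directly with 5-fold cross-validation via train(); the result may differ slightly from cv.knn because the folds are drawn differently:
## Alternative sketch: let caret tune k by 5-fold CV
set.seed(1234)
knn_cv <- train(Survived ~ ., data = train_knn,
                method = "knn",
                trControl = trainControl(method = "cv", number = 5),
                tuneGrid = data.frame(k = 2:20))
knn_cv$bestTune  # the k with the highest cross-validated accuracy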