2020/04/26
RからPythonへの道(14)
3月、4月と会社の仕事が超多忙で怒涛の日々でした。4月に入って、新型コロナウィルスによる在宅勤務対応等で労務管理時間が激増。ゆっくり息を付く余裕もありませんでした。そのままGWに入って、ほぼ自宅監禁状態なんでしょうね。ストレスをうまく発散しつつ、GWはインドアで何かのんびりと勉強、自己啓発でもしようかなと気持ちを切り替えようとしています。今回は「13. 決定木(分類)(1)」について、RとPythonで計算していきたいと思います。教材はネット上で多く引用されている定番のirisのデータセットです。
まずは、Rのコードです。irisのデータを読み込んで、13〜16行目で学習データ、評価データを7:3に分けました。18行目で学習データを用いてrpart関数で決定木で分類学習し、結果を2種類のグラフに描画しました。その後、25行目で評価データを代入し、28行目以降で性能評価をしました。
# Decision Tree : iris classificationRのコードの結果は以下の通りです。8行目のグラフは以下の通り。赤色、緑色、青色のプロット点はそれぞれ、setosa、versicolor、virginicaを表しています。
library(rpart)
library(partykit)
library(rpart.plot)
head(iris)
# Pairs graph
plot(iris[,1:4], col=c(2,3,4)[iris$Species])
# Analysis
set.seed(100)
df.rows = nrow(iris)
train.rate = 0.7 # training data rate
train.index = sample(df.rows, df.rows * train.rate)
df.train = iris[train.index,] # training data
df.test = iris[-train.index,] # test data
cat("train=", nrow(df.train), "test=", nrow(df.test))
model.rpart = rpart(Species~., data = df.train) # decition tree model
# Graph
plot(as.party(model.rpart)) # pattern 1
rpart.plot(model.rpart , type = 4, extra = 1, digits = 3) # pattern 2
# Predict
pred.rpart = predict(model.rpart, df.test, type = "class")
# Result : cross tabulation table
result = table(pred.rpart, df.test$Species)
print(result)
# Calculation the accuracy
accuracy_prediction = sum(diag(result)) / sum(result)
print(accuracy_prediction)



5行目29行目のクロス集計表の対角成分の19、12、11は正しく分類された本数で、対角成分外の数字の2と1の3本は誤分類された結果でした。正解率は「対角成分(正解数)の和」を「クロス集計表の数字の総和」で割った値なので93.3%ですね。
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1 5.1 3.5 1.4 0.2 setosa
2 4.9 3.0 1.4 0.2 setosa
3 4.7 3.2 1.3 0.2 setosa
4 4.6 3.1 1.5 0.2 setosa
5 5.0 3.6 1.4 0.2 setosa
6 5.4 3.9 1.7 0.4 setosa
17行目
train= 105 test= 45
29行目
pred.rpart setosa versicolor virginica
setosa 19 0 0
versicolor 0 12 2
virginica 0 1 11
33行目
[1] 0.9333333
次に、Pythonのコードです。流れはRのコードと同じです。
# Decision Tree : iris classificationPythonコードの結果は以下の通りです。23〜24行目のPairプロットグラフは以下の通りです。
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
import pydotplus
from sklearn.externals.six import StringIO
from IPython.display import Image
# Read data
iris = load_iris()
df = pd.DataFrame(iris.data)
df.columns = ['Sepal length', 'Sepal width', 'Petal length', 'Petal width']
df['Species'] = iris.target
df.loc[df['Species'] == 0, 'Species'] = "setosa"
df.loc[df['Species'] == 1, 'Species'] = "versicolor"
df.loc[df['Species'] == 2, 'Species'] = "virginica"
print(df.head())
# Pairs graph
g = sns.pairplot(df,hue='Species')
plt.show()
# Analysis
X = df[['Sepal length', 'Sepal width', 'Petal length', 'Petal width']]
y = df['Species']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
print("train=",str(len(X_train)), "test=", str(len(X_test)))
clf = tree.DecisionTreeClassifier(max_depth=2) # decision tree
clf = clf.fit(X_train, y_train)
# Graph
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf('DT_result.pdf')
# Predict
pred = clf.predict(X_test)
# Result : cross tabulation table
print(pd.crosstab(pred, y_test))
# Calculation the accuracy
print(sum(pred == y_test) / len(y_test))



IG = 1 - (34/105)^2 - (32/105)^2 - (39/105)^2 = 0.6643084
20行目正解率は91.1です。Rの結果でもそうでしたが、うまく分類できているのではないかと思いました。用途によっては、学習時のパラメータを過学習しない程度に調整は必要かもしれません。
Sepal length Sepal width Petal length Petal width Species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
30行目
train= 105 test= 45
44行目
Species setosa versicolor virginica
row_0
setosa 16 0 0
versicolor 0 17 3
virginica 0 1 8
47行目
0.9111111111111111
次回も決定木で分類を行う予定です。
『RからPythonへの道』バックナンバー
(1) はじめに
(2) 0. 実行環境(作業環境)
(3) 1. PythonからRを使う方法 2. RからPythonを使う方法
(4) 3. データフレーム
(5) 4. ggplot
(6) 5.行列
(7) 6.基本統計量
(8) 7. 回帰分析(単回帰)
(9) 8. 回帰分析(重回帰)
(10) 9. 回帰分析(ロジスティック回帰1)
(11) 10. 回帰分析(ロジスティック回帰2)
(12) 11. 回帰分析(リッジ、ラッソ回帰)
(13) 12. 回帰分析(多項式回帰)