読者です 読者をやめる 読者になる 読者になる

ぺーぺーSEのブログ

備忘録・メモ用サイト。

Pythonでロジスティック回帰

分類問題をPythonとロジスティック回帰で解いてみる。
データセットは下記のIrisデータセットを使用する。

blog.pepese.com

線形回帰(単回帰)は下記。

blog.pepese.com

線形回帰(重回帰)は下記。

blog.pepese.com

ライブラリ、データセットのロードまでは以下。

import numpy as np
from sklearn import datasets
from matplotlib import pyplot as plt
from sklearn import linear_model
%pylab inline --no-import-all

iris = datasets.load_iris()

データの様子が見たい場合は以下。

features = iris.data
target = iris.target
target_names = iris.target_names
labels = target_names[target]

setosa_petal_length = features[labels == 'setosa', 2]
setosa_petal_width = features[labels == 'setosa', 3]
setosa = np.c_[setosa_petal_length, setosa_petal_width]
versicolor_petal_length = features[labels == 'versicolor', 2]
versicolor_petal_width = features[labels == 'versicolor', 3]
versicolor = np.c_[versicolor_petal_length, versicolor_petal_width]
virginica_petal_length = features[labels == 'virginica', 2]
virginica_petal_width = features[labels == 'virginica', 3]
virginica = np.c_[virginica_petal_length, virginica_petal_width]

plt.scatter(setosa[:, 0], setosa[:, 1], color='red')
plt.scatter(versicolor[:, 0], versicolor[:, 1], color='blue')
plt.scatter(virginica[:, 0], virginica[:, 1], color='green')

plt.show()

ここでは、花びらの長さ(petal length)と花びらの幅(petal width)から、3種のアヤメ(setosa、versicolor、virginica)を分類する。

ライブラリで解く

scikit-learnで解く

sklearn.linear_model.logistic.LogisticRegression」を使用して解く。

X = iris.data[:, [2, 3]] # irisのデータセットの第3, 4カラム(花びらの長さと幅)
y = iris.target # irisのそれぞれのデータごとのラベル、アヤメの種類で0、1、2が格納されている

lr = linear_model.LogisticRegression(C=1e5)
# データセットから学習
lr.fit(X, y)

学習したモデル(lr)のpredict関数に花びらの長さ・幅を入力するとアヤメの種類を返してくれる。
predict関数の引数はnp.arrayでもよい。

# 花びらの長さ(x_1と置く)の値域の最小値-0.5、最大値+0.5を求める
x_1_min, x_1_max = X[:, 0].min() - .5, X[:, 0].max() + .5
# 花びらの幅(x_2と置く)の値域の最小値-0.5、最大値+0.5を求める
x_2_min, x_2_max = X[:, 1].min() - .5, X[:, 1].max() + .5
# 上記のx_1、x_2の値域から0.1刻みでプロット点を生成する
ax_1, ax_2 = np.meshgrid(np.arange(x_1_min, x_1_max, 0.1), np.arange(x_2_min, x_2_max, 0.1))

# 上記で作成した(x_1, x_2)プロット点をモデルに代入し、アヤメの種類(0or1or2)を得る
Z = lr.predict(np.c_[ax_1.ravel(), ax_2.ravel()])

# プロット点と同じ行列(行列×列)に整形する
Z = Z.reshape(ax_1.shape)

描画は下記。

# 分類結果の描画
plt.pcolormesh(ax_1, ax_2, Z, cmap=plt.cm.Paired)
# 線形回帰の結果のように直線(曲線)を描画するのではなく、
# グラフ上の点に対してアヤメの種類(0、1、2)によって色を付ける
# 結果、境界が見える


# データセットの描画
features = iris.data
target = iris.target
target_names = iris.target_names
labels = target_names[target]

setosa_petal_length = features[labels == 'setosa', 2]
setosa_petal_width = features[labels == 'setosa', 3]
setosa = np.c_[setosa_petal_length, setosa_petal_width]
versicolor_petal_length = features[labels == 'versicolor', 2]
versicolor_petal_width = features[labels == 'versicolor', 3]
versicolor = np.c_[versicolor_petal_length, versicolor_petal_width]
virginica_petal_length = features[labels == 'virginica', 2]
virginica_petal_width = features[labels == 'virginica', 3]
virginica = np.c_[virginica_petal_length, virginica_petal_width]

plt.scatter(setosa[:, 0], setosa[:, 1], color='red')
plt.scatter(versicolor[:, 0], versicolor[:, 1], color='blue')
plt.scatter(virginica[:, 0], virginica[:, 1], color='green')

plt.show()