読者です 読者をやめる 読者になる 読者になる

ぺーぺーSEのブログ

備忘録・メモ用サイト。

Pythonでロジスティック回帰

Python NumPy 機械学習 ロジスティック回帰 scikit-learn

分類問題をPythonとロジスティック回帰で解いてみる。
データセットは下記のIrisデータセットを使用する。

blog.pepese.com

線形回帰(単回帰)は下記。

blog.pepese.com

線形回帰(重回帰)は下記。

blog.pepese.com

ライブラリ、データセットのロードまでは以下。

import numpy as np
from sklearn import datasets
from matplotlib import pyplot as plt
from sklearn import linear_model
%pylab inline --no-import-all

iris = datasets.load_iris()

データの様子が見たい場合は以下。

features = iris.data
target = iris.target
target_names = iris.target_names
labels = target_names[target]

setosa_petal_length = features[labels == 'setosa', 2]
setosa_petal_width = features[labels == 'setosa', 3]
setosa = np.c_[setosa_petal_length, setosa_petal_width]
versicolor_petal_length = features[labels == 'versicolor', 2]
versicolor_petal_width = features[labels == 'versicolor', 3]
versicolor = np.c_[versicolor_petal_length, versicolor_petal_width]
virginica_petal_length = features[labels == 'virginica', 2]
virginica_petal_width = features[labels == 'virginica', 3]
virginica = np.c_[virginica_petal_length, virginica_petal_width]

plt.scatter(setosa[:, 0], setosa[:, 1], color='red')
plt.scatter(versicolor[:, 0], versicolor[:, 1], color='blue')
plt.scatter(virginica[:, 0], virginica[:, 1], color='green')

plt.show()

ここでは、花びらの長さ(petal length)と花びらの幅(petal width)から、3種のアヤメ(setosa、versicolor、virginica)を分類する。

ライブラリで解く

scikit-learnで解く

sklearn.linear_model.logistic.LogisticRegression」を使用して解く。

X = iris.data[:, [2, 3]] # irisのデータセットの第3, 4カラム(花びらの長さと幅)
y = iris.target # irisのそれぞれのデータごとのラベル、アヤメの種類で0、1、2が格納されている

lr = linear_model.LogisticRegression(C=1e5)
# データセットから学習
lr.fit(X, y)

学習したモデル(lr)のpredict関数に花びらの長さ・幅を入力するとアヤメの種類を返してくれる。
predict関数の引数はnp.arrayでもよい。

# 花びらの長さ(x_1と置く)の値域の最小値-0.5、最大値+0.5を求める
x_1_min, x_1_max = X[:, 0].min() - .5, X[:, 0].max() + .5
# 花びらの幅(x_2と置く)の値域の最小値-0.5、最大値+0.5を求める
x_2_min, x_2_max = X[:, 1].min() - .5, X[:, 1].max() + .5
# 上記のx_1、x_2の値域から0.1刻みでプロット点を生成する
ax_1, ax_2 = np.meshgrid(np.arange(x_1_min, x_1_max, 0.1), np.arange(x_2_min, x_2_max, 0.1))

# 上記で作成した(x_1, x_2)プロット点をモデルに代入し、アヤメの種類(0or1or2)を得る
Z = lr.predict(np.c_[ax_1.ravel(), ax_2.ravel()])

# プロット点と同じ行列(行列×列)に整形する
Z = Z.reshape(ax_1.shape)

描画は下記。

# 分類結果の描画
plt.pcolormesh(ax_1, ax_2, Z, cmap=plt.cm.Paired)
# 線形回帰の結果のように直線(曲線)を描画するのではなく、
# グラフ上の点に対してアヤメの種類(0、1、2)によって色を付ける
# 結果、境界が見える


# データセットの描画
features = iris.data
target = iris.target
target_names = iris.target_names
labels = target_names[target]

setosa_petal_length = features[labels == 'setosa', 2]
setosa_petal_width = features[labels == 'setosa', 3]
setosa = np.c_[setosa_petal_length, setosa_petal_width]
versicolor_petal_length = features[labels == 'versicolor', 2]
versicolor_petal_width = features[labels == 'versicolor', 3]
versicolor = np.c_[versicolor_petal_length, versicolor_petal_width]
virginica_petal_length = features[labels == 'virginica', 2]
virginica_petal_width = features[labels == 'virginica', 3]
virginica = np.c_[virginica_petal_length, virginica_petal_width]

plt.scatter(setosa[:, 0], setosa[:, 1], color='red')
plt.scatter(versicolor[:, 0], versicolor[:, 1], color='blue')
plt.scatter(virginica[:, 0], virginica[:, 1], color='green')

plt.show()