JavaScriptを有効にしてください

scikit-learnのBaggingClassifierでバギングする

 ·   6 min read

はじめに

scikit-learnには、アンサンブル学習を行うためのBaggingClassifierが実装されている。
本記事では、BaggingClassifierを用いた学習(バギング、ペースティング、ランダムサブスペース、ランダムパッチ)について解説する。

環境

scikit-learn 0.21.3

アンサンブル学習

アンサンブル学習は、複数の予測器(分類器や回帰器など)の予測結果を1つにまとめる手法である。アンサンブル学習には様々な手法があり、代表的なものを以下に示す。

  • バギング:予測器ごとにランダムに選んだデータで学習させる。
  • ブースティング:逐次的に予測器を構築し、直前の予測器の誤差を修正するように学習させる。
  • スタッキング:複数の予測器の結果をまとめる予測器(メタ学習器)を用いる。

上記の手法は、完全に分類できるものではない。例えば、バギングで学習させた予測器の結果を、スタッキングでまとめることも可能である。

本記事ではバギングおよび、それに類似する手法を扱う。
BaggingClassifierを用いると、これらの手法を簡単に実装できる。

バギング (bagging) に近い手法として、ペースティング (pasting)、ランダムサブスペース (random subspace)、ランダムパッチ (random patche) がある。
各手法の違いを下表に示す。

手法 学習インスタンスの選択 特徴量の選択
バギング 重複ありランダムサンプリング 全て選択
ペースティング 重複なしランダムサンプリング 全て選択
ランダムサブスペース 全て選択 重複ありランダムサンプリング
ランダムパッチ 重複なしランダムサンプリング 重複なしランダムサンプリング

バギングでは学習インスタンスのサンプリングが重複ありで行われ、ペースティングでは重複なしで行われる。すなわち、バギングでは、同じ予測器に対して同じ学習インスタンスが複数回選ばれることがあり得る。
一方、ペースティングでは同じ予測器は異なる学習インスタンスしか選ばれない。

ランダムサブスペースでは、特徴量の方向に、重複ありのランダムサンプリングを行う。この手法が有効なのは、以下のような場合である。

  • 元データの特徴量が非常に多い
  • 特徴量の数の2乗や3乗に比例して、予測器の計算負荷が増加する

ランダムパッチでは、学習インスタンスと特徴量の両方に対して、重複なしのランダムサンプリングを行う。

各手法を実装するには、後述のBaggingClassifierのパラメータを下表のように設定する。

bootstrap max_samples bootstrap_features max_features
バギング True 1.0未満 False 1.0
ペースティング False 1.0未満 False 1.0
ランダムサブスペース False 1.0 True 1.0未満
ランダムパッチ False 1.0未満 False 1.0未満

BaggingClassifierクラスについて

前述の通り、scikit-learnに実装されているBaggingClassifierを用いると、バギング等を簡単に実装できる。

1
2
3
4
5
BaggingClassifier(base_estimator=None, n_estimators=10, 
		  max_samples=1.0, max_features=1.0, 
		  bootstrap=True, bootstrap_features=False, 
		  oob_score=False, warm_start=False, 
		  n_jobs=None, random_state=None, verbose=0)

引数

BaggingClassifierの引数の解説は以下の通り。

base_estimator: object or None
バギング等を行う予測器のオブジェクト。Noneの場合は、決定木 (decision tree) となる。デフォルトはNone.

n_estimators: int
予測器の数。デフォルトは10。

max_samples: int or float
個々の予測器に与える最大のインスタンス数。intの場合、最大インスタンス数そのものになる。floatの場合、データXのインスタンス数にmax_samplesを掛けた数になる。デフォルトは1.0。

max_features: int or float
個々の予測器に与える最大の特徴量の数。intの場合、最大特徴量数そのものになる。floatの場合、データXの特徴量数にmax_featuresを掛けた数になる。デフォルトは1.0。

bootstrap: boolean
Trueの場合、個々の予測器に与える学習インスタンスの重複を許す。Falseの場合、学習インスタンスの重複を許さない。デフォルトはTrue.

bootstrap_features: boolean
Trueの場合、個々の予測器に与える特徴量の重複を許す。Falseの場合、特徴量の重複を許さない。デフォルトはFalse。

oob_score: boolean
Trueの場合、OOB (out-of-bag) スコアを計算する。
バギング (bagging) では個々の予測器の学習用インスタンスを重複ありサンプリングするため、
それぞれの学習器に対して、学習に用いられないインスタンスがある。これをOOBインスタンスと呼ぶ。
各予測器のOOBインスタンスに対する予測精度を計算して平均すると、アンサンブル自体の予測精度を検証することができる。これがOOBスコアである。
なお、oob_scoreで計算される「精度」は正解率 (Accuracy) であり、sklearn.metrics.accuracy_scoreと等価である。
デフォルトはFalse。

warm_start: boolian
Trueの場合、fitメソッドによる学習時に、以前の予測器を維持したまま、新たな予測器を追加する。例えば、グリッドサーチで予測器の数を増やしながら精度変化を検証したい場合に、追加の予測モデルのみ学習させれば良いので、計算時間を短縮できる。ただし、2回目以降のfitの前に、set_paramsメソッドでn_estimatorsの数を増やす必要がある。
Falseの場合、fitメソッドを使う度に学習結果がリセットされ、0から学習が行われる。
デフォルトはFalse。

n_jobs: int or None
intで並列計算数を指定する。-1の場合、CPUの全プロセッサを使用する。Noneの場合、プロセッサを1つだけ使用する。デフォルトはNone。

random_state: int, RandomState instance or None
乱数シードを指定する。デフォルトはNone。

verbose: int
学習時と予測時に、状況を表示するか設定する。数値が大きいほど詳細に表示される。
version 0.21.2では、0~2が有効である。デフォルトは0(表示しない)。

変数

BaggingClassifierクラスのメンバ変数は以下の通り。

base_estimator_ : estimator
予測器のオブジェクト。

estimators_ : list of estimators
学習した予測器オブジェクトのリスト。

estimators_samples_ : list of arrays
各予測器の学習に使用されたインスタンスの番号。

estimators_features_ : list of arrays
各予測器の学習に使用された特徴量の番号。

classes_ : array of shape = [n_classes]
予測クラスのラベル

n_classes_ : int or list
予測クラスの数。

oob_score_ : float
OOBスコア

oob_decision_function_ : array of shape = [n_samples, n_classes]
学習データ中のOOBインスタンスが、どのクラスに分類されたかを表す配列。
n_estimatorsの数が小さい場合は、oob_decision_function_にNaNが含まれる場合がある。
oob_scoreがTrueの場合に有効。

メソッド

BaggingClassifierクラスの主なメソッドは以下の通り。

fit(X, y)
Xを説明変数、yを目的変数として学習する。

predict(X)
説明変数Xのクラスを予測する。

predict_proba(X)
Xのインスタンスが各クラスに属する確率を出力する。

predict_log_proba(X)
Xのインスタンスが各クラスに属する確率の対数を出力する。

set_params(**params)
BaggingClassifierオブジェクトのパラメータを変更する。

参考

sklearn.ensemble.BaggingClassifier — scikit-learn 0.24.0 documentation

シェアする

Helve
WRITTEN BY
Helve
関西在住、電機メーカ勤務のエンジニア。X(旧Twitter)で新着記事を配信中です

サイト内検索