※記事内に商品プロモーションを含むことがあります。
はじめに
前回、KerasのRecurrentレイヤを使った時系列予測を扱った。
Kerasを使ったRNN, GRU, LSTMによる時系列予測
このとき、Reccurent層に入力するデータを下図のように変形していたが、この方法ではデータサイズが約timesteps
倍に増加してしまう。
そこで、Pythonのジェネレータ (generator) を使い、データを呼び出すときに必要なデータだけ変形することで、大容量のデータを扱う場合でもメモリが不足しないようにする。
なお、ジェネレータとは、for文などを使って要素を逐次的に出力できるオブジェクトであり、なおかつ、1要素を取り出そうとする度に処理を行うものである。本記事ではジェネレータに関する知識は必要ないが、詳細を知りたい方は以下の記事を参照のこと。
Pythonのイテレータとジェネレータ - Qiita
環境
ソフトウェア | バージョン |
---|---|
Anaconda3 | 2019.03 |
Python | 3.7.3 |
TensorFlow | 1.13.1 |
keras | 2.2.4 |
NumPy | 1.16.2 |
本記事では、Pythonで以下の通りライブラリをインポートしていることを前提とする。
|
|
Sequentialモデルのgeneratorに関するメソッド
KerasのSequentialモデルには、generatorを使った学習・検証・予測がサポートされている。
学習はfit_generator
メソッド、検証はevaluate_generator
メソッド、予測はpredict_generator
メソッドをそれぞれ用いる。
fit_generator
メソッドとpredict_generator
メソッドについて簡単に解説する。ここでは一部の引数しか記載していないため、全ての引数を知りたい方は以下のページを参考のこと。
Sequentialモデル - Keras Documentation
fit_generatorメソッド
|
|
引数の説明は以下の通り。
generator: 学習データのgeneratorクラス。呼び出す度に(inputs, targets)のタプルを返す。
epochs: エポック数 (int). デフォルト値は1.
validation_data: 検証データのgeneratorクラスまたは(inputs, targets)のタプル(任意)。
shuffle: 各試行の初めにバッチの順番をシャッフルするかどうか。デフォルト値はTrue.
predict_generatorメソッド
|
|
引数の説明は以下の通り。
generator: 説明変数を返すgeneratorクラス。
generatorクラスの作成
Recurrentレイヤに入力するためのデータを生成するgeneratorクラスを実装する。
genaratorクラスは、keras.utils.Sequence()
クラスを基底クラスとする。
また、学習(fit_generator
メソッド)では説明変数と目的変数の両方、予測(predict_generator
メソッド)では説明変数のみ扱うため、それぞれ異なるgeneratorクラスを作る。
学習用generatorクラス
次のReccurentTrainingGenerator
クラスを実装した。
|
|
以下、簡単な解説である。
Kerasの仕様上、Sequence
を継承するクラスは、__len__
, __getitem__
メソッドを備えなければならない。
__len__
メソッドは1エポックで生成するバッチ数を返す。
また、__getitem__
メソッドはReccurentTrainingGenerator
クラスを呼び出す度に実行され、説明変数と目的変数をバッチで返す。
__getitem__
メソッドで返されるbatch_x
とbatch_y
のイメージは以下の図の通りである。
ここで、x1
, x2
, x3
は異なる説明変数であり、括弧内の数字は時刻を示す。
また、batch_size
は5, timesteps
は3, delay
は1である。
delay
が1とは、時刻tまでのデータを用いて、時刻t+1のデータを予測することを意味する。
batch_x
は(バッチサイズ×timesteps×特徴量数)の3次元配列、
batch_y
は(バッチサイズ×1)の2次元配列である。
ただし、実際にはバッチ方向の時系列の並びはシャッフルされる。
学習データを格納したReccurentTrainingGenerator
をfit_generator
メソッドに渡してやればよい。
予測用generatorクラス
次のReccurentPredictingGenerator
クラスを実装した。
|
|
目的変数を出力せず、データをシャッフルする必要がない以外は、ReccurentTrainingGenerator
クラスと同じである。
__getitem__
メソッドでは先程の図のbatch_x
のみ返される。
ただし、バッチ方向の時系列の並びはシャッフルされない。
予測用データを格納したReccurentPredictingGenerator
をpredict_generator
メソッドに渡してやればよい。
記事が長くなったため、generatorの使い方は後編に分けた。
Kerasの時系列予測でgeneratorを使って大容量データを扱う 後編
参考
今回と次回の記事のコードをまとめたものをGithubにおいている。
https://gist.github.com/helve2017/c20d6106a5dab00a8afa942584b60580
Kerasの公式リファレンス。
Sequentialモデル - Keras Documentation
ユーティリティ - Keras Documentation