トレイニングハンドラー関数

トレイニングハンドラー関数は学習を実行する際に呼び出されるコード内の関数です。ハンドラー関数は以下の構文構造を使用します。

def handler(context):
    ...
  • context : 実行時のメタデータなどがこのパラメータに格納されます。

ハンドラーの引数 

コンテキストデータ

ハンドラー関数の第一引数には以下のプロパティをコンテキスト情報として持つ変数が渡されます。

プロパティ名 説明
datasets データセットエイリアスをキー、データセットIDを値に持つdict

イメージバージョン 19.04 から、コンテキストデータの型が dict に変更になりました。 実際の利用方法の変更点については、後述の実装例を参照ください。

ハンドラーの環境変数

ハンドラー関数では以下の環境変数を使用することが可能です。

環境変数名 説明
ABEJA_ORGANIZATION_ID オーガニゼーションのID
TRAINING_JOB_DEFINITION_NAME ジョブ定義の名前
TRAINING_JOB_DEFINITION_VERSION ジョブ定義のバージョン
TRAINING_JOB_ID 学習ジョブのID
HANDLER ハンドラー関数へのパス
DATASETS データセット情報
ABEJA_TRAINING_RESULT_DIR 学習結果格納用のディレクトリへの相対パス

学習結果の保存 

ABEJA Platformでは学習の結果を元にモデルバージョンを作成することができます。

学習時の出力をモデルハンドラー関数で使用するためには、学習時にABEJA_TRAINING_RESULT_DIR環境変数で指定されるディレクトリに格納する必要があります。

学習時に出力した結果は、モデルハンドラー関数実行時にABEJA_TRAINING_RESULT_DIR環境変数に格納されたディレクトリに展開されます。

ハンドラーの返り値と例外

返り値

トレイニングハンドラー関数には返り値は必要ありません。

トレイニングハンドラー関数の実装例 (Kerasを使った実装例)

Kerasを使ったトレイニングハンドラーの実装例です。

from keras.models import Sequential
from keras.callbacks import Callback

from abeja.datasets import Client as DatasetClient
from abeja.train.client import Client as TrainClient
from abeja.train.statistics import Statistics as ABEJAStatistics

class Statistics(Callback):
    """cf. https://keras.io/callbacks/"""
    def __init__(self):
        super(Statistics, self).__init__()
        self.client = TrainClient()

    def on_epoch_end(self, epoch, logs=None):
        epochs = self.params['epochs']
        statistics = ABEJAStatistics(num_epochs=epochs, epoch=epoch + 1)
        statistics.add_stage(ABEJAStatistics.STAGE_TRAIN, logs['acc'], logs['loss'])
        statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION, logs['val_acc'], logs['val_loss'])
        self.client.update_statistics(statistics)

def handler(context):
    # get dataset_id from context data
    dataset_alias = context.datasets       # for image 18.10
    # dataset_alias = context['datasets']  # for image 19.04
    dataset_id = dataset_alias['train']   # set alias specified in console

    # create abeja sdk Client
    client = DatasetClient()
    dataset = client.get_dataset(dataset_id)

    for item in dataset.dataset_items.list(prefetch=True):
        # YOU CAN ITERATE DATASET
        pass

    model = Sequential()
    # ADD SOME CODE TO FIT TO YOUR CASE
    .....
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])

    # setup for TensorBoard
    ABEJA_TRAINING_RESULT_DIR = os.environ.get('ABEJA_TRAINING_RESULT_DIR', '.')
    log_path = os.path.join(ABEJA_TRAINING_RESULT_DIR, 'logs')

    tensorboard = TensorBoard(log_dir=log_path, histogram_freq=0,
                              write_graph=True, write_images=False)
    # setup for showing training statistics
    statistics = Statistics()

    # fit and evaluate
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=[tensorboard, statistics])
    score = model.evaluate(x_test, y_test, verbose=0)

    # save model to result directory
    model.save(os.path.join(ABEJA_TRAINING_RESULT_DIR, 'model.h5'))