Training Handler Function
The training handler function is the function in your code that is called when training is performed. It uses the following syntax.
def handler(context):
    ...
- context : Runtime information, such as metadata, is stored in this parameter.
Handler Arguments
Context
The context variable is passed as the first argument of the handler function and contains the following properties.
Property Name | Description |
---|---|
datasets | Dict with dataset alias as key, dataset ID as value |
Starting with image version 19.04, the type of the context has been changed to dict. Refer to the implementation example below for how this changes actual usage.
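The difference between the two context types can be absorbed by a small helper. The following is a minimal sketch, not part of the ABEJA SDK: `get_datasets` is a hypothetical name, and it assumes only what the implementation example on this page shows, namely that older images such as 18.10 expose `context.datasets` as an attribute, while images from 19.04 onward expose `context['datasets']`. The dataset ID used here is made up for illustration.

```python
from types import SimpleNamespace


def get_datasets(context):
    """Return the alias -> dataset ID mapping from either context style.

    Hypothetical helper (not an ABEJA SDK function).
    """
    if isinstance(context, dict):
        # image 19.04 and later: the context is a dict
        return context.get('datasets', {})
    # older images such as 18.10: the context exposes attributes
    return getattr(context, 'datasets', {})


# Stand-ins for the two context styles, with a made-up dataset ID.
old_style = SimpleNamespace(datasets={'train': '1234567890123'})
new_style = {'datasets': {'train': '1234567890123'}}
```

Inside a handler, `get_datasets(context)['train']` would then return the dataset ID for the alias `train` regardless of the image version.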
Handler Environment Variable
The handler function can use the following environment variables.
Environment Variable Name | Description |
---|---|
ABEJA_ORGANIZATION_ID | ID of the organization |
TRAINING_JOB_DEFINITION_NAME | Name of the training job definition |
TRAINING_JOB_DEFINITION_VERSION | Version number of the training job definition |
TRAINING_JOB_ID | ID of the training job |
HANDLER | Path of the handler |
DATASETS | Dataset information |
ABEJA_TRAINING_RESULT_DIR | Relative path to the directory where the result of the training job is stored |
In addition to the above, environment variables that can be specified when creating code version/service/trigger/run can also be used. For more information on user-specifiable environment variables, see here.
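As an illustration, the variables in the table above can be read with the standard library. This is a minimal sketch: `read_handler_env` and `HANDLER_ENV_NAMES` are hypothetical names, and the values are returned as the raw strings found in the environment, since their exact format is not described here.

```python
import os

# Variable names taken from the table above.
HANDLER_ENV_NAMES = (
    'ABEJA_ORGANIZATION_ID',
    'TRAINING_JOB_DEFINITION_NAME',
    'TRAINING_JOB_DEFINITION_VERSION',
    'TRAINING_JOB_ID',
    'HANDLER',
    'DATASETS',
    'ABEJA_TRAINING_RESULT_DIR',
)


def read_handler_env(environ=None):
    """Collect the handler-related variables that are actually set.

    Hypothetical helper; values are left as raw strings, without parsing.
    """
    environ = os.environ if environ is None else environ
    return {name: environ[name] for name in HANDLER_ENV_NAMES if name in environ}
```

Within a handler, `read_handler_env().get('TRAINING_JOB_ID')` returns the job ID when the variable is set and `None` otherwise, which keeps the code usable when run outside the platform.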
Use of training result in the model handler
ABEJA Platform allows you to create model versions based on the results of training.
To use the training output in the model handler function, it must be stored, at training time, in the directory specified by the ABEJA_TRAINING_RESULT_DIR environment variable.
When the model handler function is executed, the training output is extracted to the directory specified by the ABEJA_TRAINING_RESULT_DIR environment variable.
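The round trip described above can be sketched as follows. This is an illustration under assumptions, not an ABEJA API: the function names and the file name `params.json` are hypothetical, and the fallback to the current directory mirrors the Keras example on this page.

```python
import json
import os


def result_dir():
    """Directory named by ABEJA_TRAINING_RESULT_DIR, falling back to '.'."""
    return os.environ.get('ABEJA_TRAINING_RESULT_DIR', '.')


def save_training_output(params, filename='params.json'):
    """Call at the end of training: write the output under the result dir."""
    os.makedirs(result_dir(), exist_ok=True)
    path = os.path.join(result_dir(), filename)
    with open(path, 'w') as f:
        json.dump(params, f)
    return path


def load_training_output(filename='params.json'):
    """Call from the model handler: read the extracted training output."""
    path = os.path.join(result_dir(), filename)
    with open(path) as f:
        return json.load(f)
```

Because both functions resolve the directory from the same environment variable, the training code and the model handler do not need to agree on an absolute path, only on the file name.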
Handler return value and exception
Return value
The handler function does not need to return a value.
Statistics
You can use abeja.train.client.Client#update_statistics to record statistics.
See the implementation example below for details.
The maximum amount of data that can be included in a single statistic is 1MB. This is based on the size of the data converted from the configured statistics to escaped JSON data.
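A rough pre-check of the size limit might look like the sketch below. It rests on several assumptions: the payload is a plain dict standing in for the statistics (not the SDK's Statistics object), "escaped JSON" is approximated with `json.dumps(..., ensure_ascii=True)`, and 1MB is read as 1,000,000 bytes. The platform may serialize or count differently, so treat the result as an estimate only.

```python
import json

# Assumption: "1MB" is read here as 1,000,000 bytes; the platform may
# count it differently (for example as 1,048,576 bytes).
MAX_STATISTICS_BYTES = 1_000_000


def statistics_json_size(payload):
    """Size in bytes of the payload serialized as ASCII-escaped JSON."""
    return len(json.dumps(payload, ensure_ascii=True).encode('utf-8'))


def fits_statistics_limit(payload, limit=MAX_STATISTICS_BYTES):
    """True if the estimated JSON size does not exceed the limit."""
    return statistics_json_size(payload) <= limit
```

Non-ASCII characters are counted in their escaped form (for example `\u00e9`), so text-heavy statistics can be noticeably larger than their character count suggests.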
Example of implementing a handler function (using Keras)
The following is an example of a training handler function that uses Keras.
import os

import keras
from keras.models import Sequential
from keras.callbacks import Callback, TensorBoard

from abeja.datasets import Client as DatasetClient
from abeja.train.client import Client as TrainClient
from abeja.train.statistics import Statistics as ABEJAStatistics


class Statistics(Callback):
    """Report per-epoch statistics. cf. https://keras.io/callbacks/"""

    def __init__(self):
        super(Statistics, self).__init__()
        self.client = TrainClient()

    def on_epoch_end(self, epoch, logs=None):
        epochs = self.params['epochs']
        statistics = ABEJAStatistics(num_epochs=epochs, epoch=epoch + 1)
        # newer Keras versions report 'accuracy' / 'val_accuracy' instead of 'acc' / 'val_acc'
        statistics.add_stage(ABEJAStatistics.STAGE_TRAIN, logs['acc'], logs['loss'])
        statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION, logs['val_acc'], logs['val_loss'])
        self.client.update_statistics(statistics)


def handler(context):
    # get dataset_id from context data
    dataset_alias = context.datasets  # for image 18.10
    # dataset_alias = context['datasets']  # for image 19.04
    dataset_id = dataset_alias['train']  # set alias specified in console

    # create abeja sdk Client
    client = DatasetClient()
    dataset = client.get_dataset(dataset_id)

    for item in dataset.dataset_items.list(prefetch=True):
        # YOU CAN ITERATE DATASET
        pass

    model = Sequential()

    # ADD SOME CODE TO FIT TO YOUR CASE
    # (x_train, y_train, x_test, y_test, batch_size and epochs used below
    # are not defined in this example and must be prepared here)
    ...

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])

    # setup for TensorBoard
    ABEJA_TRAINING_RESULT_DIR = os.environ.get('ABEJA_TRAINING_RESULT_DIR', '.')
    log_path = os.path.join(ABEJA_TRAINING_RESULT_DIR, 'logs')
    tensorboard = TensorBoard(log_dir=log_path, histogram_freq=0,
                              write_graph=True, write_images=False)

    # setup for showing training statistics
    statistics = Statistics()

    # fit and evaluate
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=[tensorboard, statistics])
    score = model.evaluate(x_test, y_test, verbose=0)

    # save model to result directory
    model.save(os.path.join(ABEJA_TRAINING_RESULT_DIR, 'model.h5'))