This module contains the classes that represent the model trainers. Each class defines all the parameters required to train the model and to test it during training.
Each model has its own trainer class, for example:
Those model trainers require different environment variables to be present, but almost all of them have fallback default values. The variables themselves are quite similar from model to model. TrainMlm.java is the only one in which all the variables are documented with comments.
IntelliJ run configurations for each of those classes can be found in the .run folder. The environment variables which need to be provided explicitly are:
In order to successfully run the training of the Masked Language (Encoder) and Auto-Regressive (Generative) models, you need to provide your own IDataProvider instance and set it using the transformer.setDataProvider() method. The current implementation uses WikiArticlesContentProvider, which relies on a corresponding MongoDB instance running with a Wikipedia dump. That was the primary data source for training, but because setting it up is quite tedious, the fastest option is to implement your own IDataProvider as a lambda or anonymous class that fetches the data you need (a sketch follows below). Please refer to the method org.tarik.train.TrainMlm#getWikiArticlesContentProvider for guidance on fetching the data in chunks.
The implementation of the getPassages(Function&lt;List&lt;String&gt;, Boolean&gt; isLimitReachedFunction) method is crucial: it is called repeatedly, depending on how much data you provide to the model. The model expects you to provide at least BATCH_SIZE (default: 128) sequences (passages) of tokens per iteration. There is also the MIN_SEQUENCE_UTILIZATION variable (default: 50%), which defines how many tokens, as a percentage of the sequence length (hardcoded as 256 for now), each sequence from the provider must contain at a minimum in order to be accepted by the model and added to the batchedTokenSentences variable. If your provider returns fewer than BATCH_SIZE sequences per call, the model will keep calling it until it gets a whole batch. Similarly, if sequences coming from the provider contain less than 50% of the required tokens (the rest will always be masked), the model skips (ignores) them and calls getPassages(...) again as many times as needed to fill the batch. This can become the source of an infinite loop if the provider always returns some data and never gets exhausted (e.g. if it is a simple mock).
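For illustration, here is a minimal in-memory sketch of such a provider. Note that this is hypothetical: only the getPassages signature above is taken from this document; the void return type, the chunk size, and the exact contract of isLimitReachedFunction are assumptions to be checked against the real IDataProvider interface.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;

/** A finite, in-memory stand-in for WikiArticlesContentProvider (sketch only). */
public class InMemoryPassageProvider /* implements IDataProvider */ {

    private static final int CHUNK_SIZE = 32; // arbitrary chunk size for this sketch

    private final Iterator<String> remaining;

    public InMemoryPassageProvider(List<String> allPassages) {
        this.remaining = allPassages.iterator();
    }

    /** Feeds passages in chunks until the caller signals that the limit is reached. */
    public void getPassages(Function<List<String>, Boolean> isLimitReachedFunction) {
        while (remaining.hasNext()) {
            List<String> chunk = new ArrayList<>(CHUNK_SIZE);
            while (remaining.hasNext() && chunk.size() < CHUNK_SIZE) {
                chunk.add(remaining.next());
            }
            // Hand the chunk over and stop once the model has collected enough
            // sequences. Crucially, this provider also terminates when its data
            // is exhausted - an always-full mock would cause the infinite loop
            // described above.
            if (isLimitReachedFunction.apply(chunk)) {
                return;
            }
        }
    }
}
```

You would then register it via transformer.setDataProvider(...) as described above.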
If you want to see more detailed logs, you can change the LOG_LEVEL environment variable value to DEBUG (it's INFO by default). You can also turn on JavaCPP debug output (useful if anything is wrong with the C++-related part) using the following code: System.setProperty("org.bytedeco.javacpp.logger.debug", "true"). There is also SameDiff logging, which lets you see the details of each executed operation; you can turn it on with this line: sd.enableDebugMode();
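Put together, a minimal debug setup might look like this (a sketch that only combines the two calls quoted above, with sd standing for your SameDiff instance):

```java
import org.nd4j.autodiff.samediff.SameDiff;

public class DebugSetup {
    public static void main(String[] args) {
        // Enable verbose JavaCPP logging; set this before any native-backed
        // classes are loaded so that the C++-related part is covered.
        System.setProperty("org.bytedeco.javacpp.logger.debug", "true");

        SameDiff sd = SameDiff.create();
        // Log the details of each operation as it executes.
        sd.enableDebugMode();
    }
}
```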
In order to create a Docker image for Linux, the following IntelliJ run configurations can be used:
Docker must be installed on the machine where those run configurations are executed. Each of those run configurations builds the image, installs it into the local registry, and exports it locally as a TAR archive. This archive can later be copied anywhere and used to start the container.
All Docker-related files are located in the docker folder:
To start the Docker container, you can use the docker run command, for example:
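The archive, image name, and tag below are placeholders; substitute whatever the run configuration actually produced:

```shell
# Load the exported TAR archive into the local registry, then start a container.
docker load -i model-trainer.tar
docker run --rm -it --name model-trainer -e LOG_LEVEL=DEBUG model-trainer:latest
```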
There is a very valuable listener class in the model class itself: CustomListener. It has several methods which let you see what happens during execution. For example, the method public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) lets you see the results and parameters of each operation right after it has been executed, which helps you understand whether the model is training at all or is stuck somewhere. Because it is quite hard to debug SameDiff directly, this listener is a good utility for observing what happens during training. It also has methods which show which weights are updated, as well as what happens before each operation is executed. You can use that listener almost always, even when you just want to check that the intermediate state of your operations is as expected.
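For illustration, here is a minimal listener in the same spirit (a sketch, not the project's CustomListener itself; it assumes the standard ND4J listener API, where BaseListener provides no-op defaults and the opExecution signature matches the one quoted above):

```java
import java.util.Arrays;

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/** A sketch of an op-level debugging listener in the spirit of CustomListener. */
public class OpDebugListener extends BaseListener {

    @Override
    public boolean isActive(Operation operation) {
        return true; // listen during every phase (training, inference, evaluation)
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                            OpContext opContext, INDArray[] outputs) {
        // Log the name of each executed op plus the shape and mean of its outputs,
        // which makes NaNs, dead activations, or a stuck model easy to spot.
        for (int i = 0; i < outputs.length; i++) {
            INDArray out = outputs[i];
            if (out == null) {
                continue;
            }
            System.out.printf("op=%s output[%d] shape=%s mean=%s%n",
                    op.getName(), i, Arrays.toString(out.shape()), out.meanNumber());
        }
    }
}
```

Such a listener would be attached before training with sd.setListeners(new OpDebugListener());.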