Training Models

Thinc provides a fairly minimalistic approach to training, leaving you in control to write the training loop. The library provides a few utilities for minibatching, hyperparameter scheduling, loss functions and weight initialization, but does not provide abstractions for data loading, progress tracking or hyperparameter optimization.

The training loop

Thinc assumes that your model will be trained using some form of minibatched stochastic gradient descent. On each step of a standard training loop, you’ll loop over batches of your data and call Model.begin_update on the inputs of the batch, which will return a batch of predictions and a backpropagation callback. You’ll then calculate the gradient of the loss with respect to the output, and provide it to the backprop callback which will increment the gradients of the model parameters as a side-effect. You can then pass an optimizer function into the Model.finish_update method to update the weights.

Basic training loop

for i in range(10):
    for X, Y in train_batches:
        Yh, backprop = model.begin_update(X)
        loss, dYh = get_loss_and_gradient(Yh, Y)
        backprop(dYh)
        model.finish_update(optimizer)

You’ll usually want to make some additions to the loop to save out model checkpoints periodically, and to calculate and report progress statistics. Thinc also provides ready access to lower-level details, making it easy to experiment with arbitrary training variations. You can accumulate the gradients over multiple batches before calling the optimizer, call the backprop callback multiple times (or not at all if the update is small), and inject arbitrary code to change or report gradients for particular layers. The implementation is quite transparent, so you’ll find it easy to implement such arbitrary modifications if you need to.
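One of the variations mentioned above, gradient accumulation, can be sketched without any library code. The `ToyModel` class and `sgd` function below are hypothetical stand-ins that only mimic the `begin_update` / `finish_update` pattern; they are not Thinc classes. The point is that the backprop callback adds to the stored gradient, so you can call it for several batches before applying the optimizer.

```python
from typing import Callable, List, Tuple


class ToyModel:
    """Hypothetical one-weight model (y = weight * x), not a Thinc class."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight
        self.grad = 0.0  # accumulated gradient of the loss w.r.t. the weight

    def begin_update(
        self, X: List[float]
    ) -> Tuple[List[float], Callable[[List[float]], None]]:
        Yh = [self.weight * x for x in X]

        def backprop(dYh: List[float]) -> None:
            # Increment (rather than overwrite) the gradient as a side effect.
            self.grad += sum(d * x for d, x in zip(dYh, X))

        return Yh, backprop

    def finish_update(self, optimizer: Callable[[float, float], float]) -> None:
        # Apply the accumulated gradient, then reset it for the next update.
        self.weight = optimizer(self.weight, self.grad)
        self.grad = 0.0


def sgd(weight: float, grad: float, learn_rate: float = 0.01) -> float:
    return weight - learn_rate * grad


def train(model: ToyModel, batches, accumulate_every: int = 2, epochs: int = 200):
    for _ in range(epochs):
        for step, (X, Y) in enumerate(batches, start=1):
            Yh, backprop = model.begin_update(X)
            # Gradient of the mean squared error w.r.t. the predictions.
            dYh = [2 * (yh - y) / len(X) for yh, y in zip(Yh, Y)]
            backprop(dYh)
            # Only update the weights after every `accumulate_every` batches.
            if step % accumulate_every == 0:
                model.finish_update(sgd)
    return model


# The targets follow y = 3 * x, so the weight should approach 3.0.
batches = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
model = train(ToyModel(), batches)
```

In this sketch, any gradient left over when the number of batches is not a multiple of `accumulate_every` simply carries into the next epoch; whether that is what you want depends on your setup.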

Batching

Thinc implements two batching helpers via the backend object Ops, typically used via model.ops. They should cover the most common batching needs for training and evaluation.

  1. minibatch: Iterate slices from a sequence, optionally shuffled.
  2. multibatch: Minibatch one or more sequences and yield lists with one batch per sequence.

Example

batches = model.ops.minibatch(128, data, shuffle=True)
batches = model.ops.multibatch(128, train_X, train_Y, shuffle=True)

The batching methods take sequences of data and process them as a stream. They return a SizedGenerator, a simple custom dataclass for generators that has a __len__ and can repeatedly call the generator function. This also means that the batching works nicely with progress bars like tqdm and similar tools out-of-the-box.

With progress bar

from tqdm import tqdm

data = model.ops.multibatch(128, train_X, train_Y, shuffle=True)
for X, Y in tqdm(data, leave=False):
    Yh, backprop = model.begin_update(X)

SizedGenerator objects hold a reference to the generator function and call it again every time the sized generator is iterated over. This means that the sized generator is never exhausted. If you like, you can define it once outside your training loop, and on each iteration, the data will be rebatched and reshuffled.

Option 1

for i in range(10):
    for X, Y in model.ops.multibatch(128, train_X, train_Y, shuffle=True):
        ...  # Update the model here
    for X, Y in model.ops.multibatch(128, dev_X, dev_Y):
        ...  # Evaluate the model here

Option 2

train_data = model.ops.multibatch(128, train_X, train_Y, shuffle=True)
dev_data = model.ops.multibatch(128, dev_X, dev_Y)
for i in range(10):
    for X, Y in train_data:
        ...  # Update the model here
    for X, Y in dev_data:
        ...  # Evaluate the model here
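To see why a sized generator can be defined once and reused across epochs, here is a minimal pure-Python sketch of the idea. `ReusableBatches` and this `minibatch` function are hypothetical stand-ins written for illustration; they are not Thinc's `SizedGenerator` or `Ops.minibatch`, and the real implementations may differ.

```python
import random
from typing import Callable, Iterator, List, Sequence


class ReusableBatches:
    """Hypothetical stand-in for a sized, re-iterable generator."""

    def __init__(self, get_items: Callable[[], Iterator[List]], length: int) -> None:
        self.get_items = get_items
        self.length = length

    def __len__(self) -> int:
        # Having a length lets progress bars show the total number of batches.
        return self.length

    def __iter__(self) -> Iterator[List]:
        # Call the generator function afresh on every iteration,
        # so the object is never exhausted.
        return self.get_items()


def minibatch(size: int, data: Sequence, shuffle: bool = False) -> ReusableBatches:
    def get_batches() -> Iterator[List]:
        indices = list(range(len(data)))
        if shuffle:
            # Reshuffled each time the batches are iterated over.
            random.shuffle(indices)
        for start in range(0, len(indices), size):
            yield [data[i] for i in indices[start : start + size]]

    n_batches = (len(data) + size - 1) // size
    return ReusableBatches(get_batches, n_batches)


batches = minibatch(4, list(range(10)), shuffle=True)
for epoch in range(2):
    # Each pass yields every item exactly once, in a fresh order.
    seen = sorted(x for batch in batches for x in batch)
```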

The minibatch and multibatch methods also support a buffer argument, which may be useful to promote better parallelism. If you’re using an engine that supports asynchronous execution, such as PyTorch or JAX, an unbuffered stream could cause the engine to block unnecessarily. If you think this may be a problem, try setting a higher buffer, e.g. buffer=500, and see if it solves the problem. You could also simply consume the entire generator, by calling list() on it.

Finally, minibatch and multibatch support variable length batching, based on a schedule you can provide as the batch_size argument. Simply pass in an iterable. Variable length batching is non-standard, but we regularly use it for some of spaCy’s models, especially the parser and entity recognizer.

from thinc.api import compounding

batch_size = compounding(1.0, 16.0, 1.001)
train_data = model.ops.multibatch(batch_size, train_X, train_Y, shuffle=True)

config.cfg

[batch_size]
@schedules = "compounding.v1"
start = 1.0
stop = 16.0
compound = 1.001

Usage

from thinc.api import Config, registry

# from_disk loads a config file from a path; from_str expects the config text itself
config = Config().from_disk("./config.cfg")
resolved = registry.resolve(config)
batch_size = resolved["batch_size"]
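To get a feel for what a compounding batch size schedule does, the following pure-Python sketch reimplements the idea from scratch. Both `compounding_sketch` and `variable_minibatch` are hypothetical names made up for this illustration, and truncating each size to an integer is an assumption of the sketch rather than documented Thinc behavior.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Sequence


def compounding_sketch(start: float, stop: float, compound: float) -> Iterator[float]:
    """Yield start, start * compound, start * compound**2, ..., clipped at stop."""
    curr = float(start)
    while True:
        yield min(curr, stop) if start <= stop else max(curr, stop)
        curr *= compound


def variable_minibatch(sizes: Iterable[float], data: Sequence) -> Iterator[List]:
    """Slice data into batches whose sizes are drawn from the schedule.

    Assumes the schedule yields at least as many sizes as there are batches.
    """
    size_iter = iter(sizes)
    start = 0
    while start < len(data):
        size = max(1, int(next(size_iter)))  # assumption: truncate to an int
        yield list(data[start : start + size])
        start += size


# A steep schedule so the growth is visible on a tiny dataset.
sizes = list(islice(compounding_sketch(1.0, 4.0, 2.0), 5))
batches = list(variable_minibatch(compounding_sketch(1.0, 4.0, 2.0), range(10)))
```

With the much gentler factor of 1.001 used in the example above, the size grows by 0.1% per batch, so it doubles roughly every 700 batches.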

Evaluation

Thinc does not provide utilities for calculating accuracy scores over either individual samples or whole datasets. In most situations, you will loop over batches of your inputs and targets, calculate the accuracy on each batch, and keep a running tally of the scores.

def evaluate(model, batch_size, Xs, Ys):
    correct = 0.
    total = 0.
    for X, Y in model.ops.multibatch(batch_size, Xs, Ys):
        # Rows are samples and columns are classes, so take the argmax per row
        correct += (model.predict(X).argmax(axis=1) == Y.argmax(axis=1)).sum()
        total += X.shape[0]
    return correct / total

During evaluation, take care to run your model in a prediction context (as opposed to a training context), by using either the Model.predict method, or by passing the is_train=False flag to Model.__call__. Some layers may behave differently during training and prediction in order to provide regularization. Dropout layers are the most common example.
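The difference between the two contexts is easiest to see with dropout. The sketch below is a generic "inverted dropout" function written in plain Python; `dropout_sketch` is a hypothetical helper, not Thinc's Dropout layer, and it assumes a drop rate below 1.0.

```python
import random
from typing import List


def dropout_sketch(
    X: List[float], rate: float, is_train: bool, rng: random.Random
) -> List[float]:
    """Zero out values at random during training; pass them through otherwise."""
    if not is_train or rate <= 0.0:
        # Prediction context: the layer behaves as the identity function.
        return list(X)
    keep = 1.0 - rate
    # Rescale the surviving values so the expected activation is unchanged.
    return [x / keep if rng.random() < keep else 0.0 for x in X]


rng = random.Random(0)
X = [1.0, 2.0, 3.0, 4.0]
predicted = dropout_sketch(X, rate=0.5, is_train=False, rng=rng)
trained = dropout_sketch(X, rate=0.5, is_train=True, rng=rng)
```

Evaluating with the training behavior switched on would score the model on randomly corrupted activations, which is why the prediction context matters.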


Loss calculators

When training your Thinc models, the most important loss calculation is not a scalar loss, but rather the gradient of the loss with respect to your model output. That’s the figure you have to pass into the backprop callback. You actually don’t need to calculate the scalar loss at all, although it’s often helpful as a diagnostic statistic.

Thinc provides a few helpers for common loss functions. Each helper is provided as a class, so you can pass in any settings or hyperparameters that your loss might require. The helper class can be used as a callable object, in which case it will return both the gradient of the loss with respect to the outputs and the scalar loss, in that order. You can also call the get_grad method to just get the gradients, or the get_loss method to just get the scalar loss.

Example

from thinc.api import CategoricalCrossentropy

loss_calc = CategoricalCrossentropy()
grad, loss = loss_calc(guesses, truths)

config.cfg

[loss]
@losses = "CategoricalCrossentropy.v1"
normalize = true
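As a worked illustration of why the gradient matters more than the scalar, here is a textbook categorical cross-entropy in plain Python. When a model ends in a softmax, the gradient of the loss with respect to the pre-softmax scores simplifies to `guesses - truths`. The function `cross_entropy_sketch` is a hypothetical helper; it mirrors the `(grad, loss)` return order of the example above, but the values Thinc's own loss class reports, the scalar in particular, may be computed differently.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def cross_entropy_sketch(
    guesses: Matrix, truths: Matrix, normalize: bool = True
) -> Tuple[Matrix, float]:
    """Return (gradient, loss) for softmax outputs and one-hot targets."""
    # With normalize=True, average over the samples in the batch.
    n = len(guesses) if normalize else 1
    grad = [
        [(g - t) / n for g, t in zip(g_row, t_row)]
        for g_row, t_row in zip(guesses, truths)
    ]
    # Only the probability assigned to the gold class contributes to the loss.
    loss = -sum(
        t * math.log(g)
        for g_row, t_row in zip(guesses, truths)
        for g, t in zip(g_row, t_row)
        if t > 0
    ) / n
    return grad, loss


guesses = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
truths = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
grad, loss = cross_entropy_sketch(guesses, truths)
```

Note that the gradient can be computed without ever evaluating the logarithm, which is why the scalar loss is optional during training.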

Setting learning rate schedules

A common trick for stochastic gradient descent is to vary the learning rate or other hyperparameters over the course of training. Since there are many possible ways to vary the learning rate, Thinc lets you implement hyperparameter schedules as instances of the Schedule class. Thinc also provides a number of popular schedules built-in.

You can use schedules directly, by calling the schedule with the step keyword argument and using it to update hyperparameters in your training loop. Since schedules are particularly common for optimization settings, the Optimizer object accepts floats, lists, iterators, and Schedule instances for most of its parameters. When you call Optimizer.step_schedules, the optimizer will increase its step count and pass it to the schedules. For instance, this is how one creates an instance of the Adam optimizer with a custom learning rate schedule:

Custom learning rate schedule

from thinc.api import Adam, Schedule

def cycle():
    values = [0.001, 0.01, 0.1]
    all_values = values + list(reversed(values))
    return Schedule("cycle", _cycle_schedule, attrs={"all_values": all_values})

def _cycle_schedule(schedule: Schedule, step: int, **kwargs) -> float:
    all_values = schedule.attrs["all_values"]
    return all_values[step % len(all_values)]

optimizer = Adam(learn_rate=cycle())
assert optimizer.learn_rate(optimizer.step) == 0.001
optimizer.step_schedules()
assert optimizer.learn_rate(optimizer.step) == 0.01
optimizer.step_schedules()
assert optimizer.learn_rate(optimizer.step) == 0.1
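A schedule is ultimately just a function of the step count. The sketch below shows a generic warmup followed by linear decay as a plain Python function, evaluated at a few steps the way you might inside your own training loop. `warmup_linear_sketch` and its parameter names are hypothetical, and Thinc's built-in schedules may differ in detail.

```python
def warmup_linear_sketch(
    step: int, initial_rate: float, warmup_steps: int, total_steps: int
) -> float:
    """Ramp up linearly to initial_rate, then decay linearly to zero."""
    if step < warmup_steps:
        factor = step / max(1, warmup_steps)
    else:
        factor = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return factor * initial_rate


# Learning rate at every fifth step of a 100-step run with 10 warmup steps.
rates = [warmup_linear_sketch(step, 0.01, 10, 100) for step in range(0, 101, 5)]
```

The rate starts at zero, peaks at `initial_rate` once the warmup ends, and returns to zero at `total_steps`.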

You’ll often want to describe your optimization schedules in your configuration file. That’s also very easy: you can use the @thinc.registry.schedules decorator to register your function, and then refer to it in your config as the learn_rate argument of the optimizer. Check out the documentation on config files for more examples.

Registered function

import thinc
from thinc.api import Schedule

@thinc.registry.schedules("cycle.v1")
def cycle(values):
    all_values = values + list(reversed(values))
    return Schedule("cycle", _cycle_schedule, attrs={"all_values": all_values})

def _cycle_schedule(schedule: Schedule, step: int, **kwargs) -> float:
    all_values = schedule.attrs["all_values"]
    return all_values[step % len(all_values)]

config.cfg

[optimizer]
@optimizers = "Adam.v1"

[optimizer.learn_rate]
@schedules = "cycle.v1"
values = [0.001, 0.01, 0.1]
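If the link between the decorator and the config block seems opaque, the following toy registry shows the general mechanism in plain Python. Everything here (`SCHEDULES`, `register_schedule`, `resolve_block`) is hypothetical and greatly simplified; it is not how Thinc's registry is implemented, only a sketch of the name-to-function lookup that a config block relies on.

```python
from typing import Any, Callable, Dict, List

# Maps registered names such as "cycle.v1" to the functions that build schedules.
SCHEDULES: Dict[str, Callable[..., Any]] = {}


def register_schedule(name: str) -> Callable:
    def decorator(func: Callable) -> Callable:
        SCHEDULES[name] = func
        return func

    return decorator


@register_schedule("cycle.v1")
def cycle(values: List[float]) -> Callable[[int], float]:
    all_values = values + list(reversed(values))
    return lambda step: all_values[step % len(all_values)]


def resolve_block(block: Dict[str, Any]) -> Any:
    """Look up the function named under "@schedules" and call it with the other keys."""
    kwargs = {key: value for key, value in block.items() if not key.startswith("@")}
    return SCHEDULES[block["@schedules"]](**kwargs)


# Equivalent in spirit to the [optimizer.learn_rate] block above.
learn_rate = resolve_block({"@schedules": "cycle.v1", "values": [0.001, 0.01, 0.1]})
```

Versioned names such as `cycle.v1` let you change a function's behavior later under a new name without breaking existing config files.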

Distributed training

We expect to recommend Ray for distributed training. Ray offers a clean and simple API that fits well with Thinc’s model design. Full support is still under development.