PyTorch, TensorFlow & MXNet
Interoperability with machine learning frameworks

Wrapping models from other frameworks is a core use case for Thinc: we want to make it easy for people to write spaCy components using their preferred machine learning solution. We expect a lot of code-bases will have similar requirements. As well as wrapping whole models, Thinc lets you call into an external framework for just part of your model: you can have a model where you use PyTorch just for the transformer layers, using “native” Thinc layers to do fiddly input and output transformations and add on task-specific “heads”, as efficiency is less of a consideration for those parts of the network.

How it works

Thinc uses a special class, Shim, to hold references to external objects. This allows each wrapper to define a custom shim type, with whatever attributes and methods are helpful, to manage the communication between Thinc and the external library. The Model class holds shim instances in a separate list and communicates with the shims about updates, serialization, changes of device, etc.
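The delegation described above can be pictured with a toy, framework-free sketch. ToyShim and ToyModel are hypothetical classes invented for illustration, not Thinc's actual Shim and Model; they only show the shape of the relationship: the model keeps its shims in their own list and forwards device changes, updates and serialization to each one.

```python
from typing import Any, List


class ToyShim:
    """Hypothetical stand-in for a Shim: holds a reference to an external object."""

    def __init__(self, external: Any) -> None:
        self.external = external  # e.g. a torch.nn.Module in a real wrapper
        self.device = "cpu"
        self.n_updates = 0

    def to_device(self, device: str) -> None:
        # A real shim would move the framework's parameters here.
        self.device = device

    def finish_update(self) -> None:
        # A real shim would apply the optimizer step to the external weights.
        self.n_updates += 1

    def to_bytes(self) -> bytes:
        # A real shim would use the framework's own serialization.
        return repr(self.external).encode("utf8")


class ToyModel:
    """Hypothetical stand-in for a Model that owns shims next to its own params."""

    def __init__(self, shims: List[ToyShim]) -> None:
        self.params: dict = {}  # the model's own (Thinc-side) parameters
        self.shims = list(shims)  # external objects live in a separate list

    def to_device(self, device: str) -> None:
        for shim in self.shims:
            shim.to_device(device)

    def finish_update(self) -> None:
        for shim in self.shims:
            shim.finish_update()

    def to_bytes(self) -> List[bytes]:
        return [shim.to_bytes() for shim in self.shims]


model = ToyModel([ToyShim({"weights": [0.1, 0.2]})])
model.to_device("gpu:0")
model.finish_update()
print(model.shims[0].device, model.shims[0].n_updates)  # gpu:0 1
```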

The wrapper will receive each batch of inputs, convert them into a suitable form for the underlying model instance, and pass them over to the shim, which will manage the actual communication with the model. The output is then passed back into the wrapper, and converted for use in the rest of the network. The equivalent procedure happens during backpropagation. Array conversion is handled via the DLPack standard wherever possible, so that data can be passed between the frameworks without copying the data back to the host device unnecessarily.
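DLPack itself is easy to try out. The sketch below assumes a recent numpy (one that provides numpy.from_dlpack) and a recent PyTorch (one that provides torch.from_dlpack); it hands a buffer from numpy to PyTorch and back without copying. It runs on the CPU so that it works anywhere; for GPU arrays such as cupy arrays and CUDA tensors, the same protocol is what lets the data stay on the device.

```python
import numpy
import torch

host = numpy.zeros((2, 3), dtype="float32")

# Wrap the numpy buffer as a torch tensor via DLPack; no data is copied.
as_torch = torch.from_dlpack(host)
as_torch += 1.0  # in-place update on the shared memory

print(host[0, 0])  # 1.0, the change is visible from numpy

# And back again: the result still points at the same memory.
back = numpy.from_dlpack(as_torch)
print(numpy.shares_memory(back, host))  # True
```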

| Framework | Wrapper layer | Shim | DLPack |
| --- | --- | --- | --- |
| PyTorch | PyTorchWrapper | PyTorchShim | Yes |
| TensorFlow | TensorFlowWrapper | TensorFlowShim | Experimental |
| MXNet | MXNetWrapper | MXNetShim | Yes |

Integrating models

The PyTorchWrapper and TensorFlowWrapper layers allow you to easily use your predefined models in Thinc, as part or all of your network. For simple models that accept one array as input and return one array as output, all you need to do is create the PyTorch/TensorFlow layer and pass it into the wrapper. The wrapper model will behave like any other Thinc layer.

PyTorch example

```python
from thinc.api import PyTorchWrapper, chain, Linear
import torch.nn

model = chain(
    PyTorchWrapper(torch.nn.Linear(16, 8)),  # 🚨 PyTorch goes (nI, nO)!
    Linear(4, 8),  # Thinc goes (nO, nI)
)
X = model.ops.alloc2f(1, 16)  # make a dummy batch
model.initialize(X=X)
Y, backprop = model(X, is_train=True)
dX = backprop(Y)
```

TensorFlow example

```python
from thinc.api import TensorFlowWrapper, chain, Linear
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = chain(
    TensorFlowWrapper(Sequential([Dense(8, input_shape=(16,))])),
    Linear(4, 8),
)
X = model.ops.alloc2f(1, 16)  # make a dummy batch
model.initialize(X=X)
Y, backprop = model(X, is_train=True)
dX = backprop(Y)
```

In theory, you can also chain together layers and models written in PyTorch and TensorFlow. However, this is likely a bad idea for actual production use, especially since TensorFlow by default reserves most of the GPU's memory for itself. It could come in handy during development, though, for instance if you need to port your models from one framework to another.

Frankenmodel

```python
from thinc.api import PyTorchWrapper, TensorFlowWrapper, chain, Linear
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import torch.nn

model = chain(  # 🚨 probably don't do this in production
    PyTorchWrapper(torch.nn.Linear(16, 8)),
    TensorFlowWrapper(Sequential([Dense(4, input_shape=(8,))])),
    Linear(2, 4),
)
model.initialize(X=model.ops.alloc2f(1, 16))
```

For more complex cases, you can control the way data is passed into the wrapper using the convert_inputs and convert_outputs callbacks. Both callbacks have input signatures like normal forward functions and return a tuple with their output and a callback to handle the backward pass. However, the converters will send and receive different data during the backward pass.

| | Forward | Backward |
| --- | --- | --- |
| Input | Thinc → Framework | Framework → Thinc |
| Output | Framework → Thinc | Thinc → Framework |

To convert arrays, you can use the xp2* and *2xp utility functions, which translate to and from numpy and cupy arrays.

| Framework | From numpy / cupy | To numpy / cupy |
| --- | --- | --- |
| PyTorch | xp2torch | torch2xp |
| TensorFlow | xp2tensorflow | tensorflow2xp |
| MXNet | xp2mxnet | mxnet2xp |
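A minimal round trip through the PyTorch helpers might look like this. It assumes Thinc and PyTorch are installed and runs on the CPU, where xp2torch turns a numpy array into a torch tensor and torch2xp turns a tensor back into a numpy array.

```python
import numpy
import torch
from thinc.api import torch2xp, xp2torch

X = numpy.arange(6, dtype="float32").reshape(2, 3)

# numpy -> torch: the direction used when feeding a wrapped PyTorch model
X_torch = xp2torch(X, requires_grad=False)
Y_torch = X_torch * 2

# torch -> numpy: the direction used when handing results back to Thinc
Y = torch2xp(Y_torch)
print(type(Y).__name__, Y.tolist())  # ndarray [[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]
```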

convert_inputs function

The ArgsKwargs object is a little dataclass that represents the tuple (args, kwargs). Whatever you put in the ArgsKwargs you return from your convert_inputs function will be passed directly into PyTorch/TensorFlow/etc. In the backward pass, the callback you returned receives the gradients from PyTorch in matching positions on another ArgsKwargs object, and should return an object that matches the original input, so that the gradient can be passed back down the Thinc model.

| Argument | Type | Description |
| --- | --- | --- |
| model | Model | The wrapper layer. |
| inputs | Any | The input to the layer. |
| is_train | bool | A flag indicating training context. |
| RETURNS | Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]] | Inputs to the PyTorchShim, and a callback that receives the input gradients from PyTorch and returns the converted gradients. |
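As an illustration, suppose a PyTorch module takes two tensors as separate positional arguments, while the Thinc side passes the layer a pair of arrays. The module AddTwo and the converter pair_to_args below are hypothetical names made up for this sketch; it assumes Thinc and PyTorch are installed, runs on the CPU, and leaves convert_outputs at its default.

```python
import numpy
import torch
from thinc.api import PyTorchWrapper, torch2xp, xp2torch
from thinc.types import ArgsKwargs


class AddTwo(torch.nn.Module):
    """Hypothetical module whose forward() takes two positional tensors."""

    def __init__(self, nI: int, nO: int) -> None:
        super().__init__()
        self.left = torch.nn.Linear(nI, nO)
        self.right = torch.nn.Linear(nI, nO)

    def forward(self, a, b):
        return self.left(a) + self.right(b)


def pair_to_args(model, inputs, is_train):
    A, B = inputs  # the Thinc-side input is a pair of 2d arrays
    torch_inputs = ArgsKwargs(
        args=(
            xp2torch(A, requires_grad=is_train),
            xp2torch(B, requires_grad=is_train),
        ),
        kwargs={},
    )

    def grads_to_pair(d_torch: ArgsKwargs):
        # The input gradients arrive in the same positions as the inputs above;
        # return them in the same shape as the original Thinc input.
        dA, dB = d_torch.args
        return torch2xp(dA), torch2xp(dB)

    return torch_inputs, grads_to_pair


model = PyTorchWrapper(AddTwo(16, 8), convert_inputs=pair_to_args)
A = numpy.ones((2, 16), dtype="float32")
B = numpy.ones((2, 16), dtype="float32")
Y, backprop = model((A, B), is_train=True)
dA, dB = backprop(Y)
print(Y.shape, dA.shape, dB.shape)  # (2, 8) (2, 16) (2, 16)
```

Returning a tuple from grads_to_pair mirrors the tuple the layer received, which is what lets the gradient flow back into whatever Thinc layer produced the pair.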

convert_outputs function

For the output, the converter will receive a tuple that contains the original input (i.e. whatever was passed into convert_inputs) and the output from the PyTorch/TensorFlow/etc. layer. The input is provided because the output may not contain all the information you need to manage the conversion. The convert_outputs function should return the output object for the rest of the network, converting via the *2xp helpers as necessary, along with the un-convert callback.

The un-convert callback will receive a gradient object from the Thinc layer above it in the network and will return an ArgsKwargs object. For PyTorch, that object is passed into torch.autograd.backward; for TensorFlow, it is passed into the TensorFlow model and tensorflow.GradientTape.watch.

| Argument | Type | Description |
| --- | --- | --- |
| model | Model | The wrapper layer. |
| outputs | Tuple[Any, Any] | A tuple of the original inputs and the PyTorch model's outputs. |
| RETURNS | Tuple[Any, Callable[[Any], ArgsKwargs]] | A tuple of the converted outputs for the rest of the network, and a callback that takes the output gradients from Thinc and returns the output gradients for PyTorch. |
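For the output side, a common case is a module that returns more than Thinc needs. In the sketch below, the hypothetical module MainAndNorm returns a tuple of its main output and an auxiliary statistic, and the hypothetical converter keep_main_output passes only the main output on to Thinc. Its un-convert callback builds an ArgsKwargs that the PyTorch shim unpacks into torch.autograd.backward, supplying the output gradient through that function's grad_tensors keyword. It assumes Thinc and PyTorch are installed, runs on the CPU, and keeps the default convert_inputs.

```python
import numpy
import torch
from thinc.api import PyTorchWrapper, torch2xp, xp2torch
from thinc.types import ArgsKwargs


class MainAndNorm(torch.nn.Module):
    """Hypothetical module returning (main output, auxiliary statistic)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        y = self.linear(x)
        return y, y.detach().norm(dim=1)


def keep_main_output(model, inputs_outputs, is_train):
    X, (Y_torch, _norms) = inputs_outputs  # original input, framework output
    Y = torch2xp(Y_torch)  # only the main output goes to the rest of the network

    def to_torch_backward(dY):
        # Unpacked as torch.autograd.backward(Y_torch, grad_tensors=dY_torch)
        return ArgsKwargs(args=(Y_torch,), kwargs={"grad_tensors": xp2torch(dY)})

    return Y, to_torch_backward


model = PyTorchWrapper(MainAndNorm(), convert_outputs=keep_main_output)
X = numpy.ones((3, 16), dtype="float32")
Y, backprop = model(X, is_train=True)
dX = backprop(numpy.ones_like(Y))
print(Y.shape, dX.shape)  # (3, 8) (3, 16)
```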

More specific PyTorch layers

Thinc also includes some more specific PyTorch layers for common use cases. The PyTorchLSTM layer creates and wraps a torch.nn.LSTM object, which makes creating an LSTM particularly easy. The PyTorchRNNWrapper provides a little more flexibility, allowing you to pass in a custom sequence model that has the same input and output behavior as a torch.nn.RNN object.
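As a rough sketch of how little setup the LSTM layer needs, the example below builds a small bidirectional PyTorchLSTM and runs it over two sequences of different lengths. It assumes Thinc and PyTorch are installed, that the positional arguments are the output and input widths (nO, then nI), and that the layer accepts a plain list of 2d arrays and handles the padding internally.

```python
import numpy
from thinc.api import PyTorchLSTM

# Output width 8, input width 4; bi=True splits the 8 output units across
# the two directions. The torch.nn.LSTM weights exist as soon as it is built.
lstm = PyTorchLSTM(8, 4, bi=True, depth=1)

# Two sequences of different lengths, each row a 4-dimensional feature vector.
seqs = [numpy.zeros((5, 4), dtype="float32"), numpy.zeros((3, 4), dtype="float32")]

Ys, backprop = lstm(seqs, is_train=True)
dXs = backprop(Ys)
print([Y.shape for Y in Ys])  # [(5, 8), (3, 8)]
```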

Avoiding memory contention (experimental)

If you use the PyTorchWrapper for part of your network while using Thinc's layers for other parts, you may find yourself running out of GPU memory unexpectedly. This can occur because both PyTorch and cupy reserve their own internal memory pools, and the two libraries do not communicate with each other. When PyTorch needs more memory, it can only ask the device, so you may get an out-of-memory error even though cupy's pool has plenty of spare memory available.

The best solution to this problem is to reroute the memory requests so that only one library is in charge. Specifically, cupy offers a cupy.cuda.set_allocator function, which lets you install a custom allocator that requests its memory via PyTorch. Thinc provides a handy shortcut for this via the use_pytorch_for_gpu_memory helper function. We're hoping to add a helper for TensorFlow in the future, once DLPack is supported in TensorFlow.
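In practice the helper is called once, after selecting the GPU and before any model allocates GPU memory. The sketch below assumes Thinc and PyTorch are installed (plus cupy on a GPU machine); on a CPU-only machine prefer_gpu returns False and nothing is rerouted.

```python
from thinc.api import prefer_gpu, use_pytorch_for_gpu_memory

# Try to activate the GPU; prefer_gpu() returns False when none is available.
on_gpu = prefer_gpu()

if on_gpu:
    # From here on, cupy requests its memory through PyTorch, so PyTorch's
    # allocator is in charge of the GPU memory used by both libraries.
    use_pytorch_for_gpu_memory()

print("GPU active:", on_gpu)
```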