Backends & Math

All Thinc models have a reference to an Ops instance, which provides access to memory allocation and mathematical routines. The Model.ops instance also keeps track of state and settings, so that you can have different models in your network executing on different devices or delegating to different underlying libraries.

Each Ops instance holds a reference to a numpy-like module (numpy or cupy), which you can access at Model.ops.xp. This is enough to make most layers work on both CPU and GPU devices. Additionally, there are several routines that we have implemented as methods on the Ops object, so that specialized versions can be called for different backends. You can also create your own Ops subclasses with specialized routines for your layers, and use the set_current_ops function to change the default.
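For instance, the following minimal sketch (assuming a CPU-only installation, so Model.ops is a NumpyOps instance) shows how to inspect the array module and change the global default backend:

from thinc.api import Linear, get_current_ops, get_ops, set_current_ops

model = Linear(4, 2)
print(model.ops.xp.__name__)  # "numpy" on CPU, "cupy" on GPU

# Make NumpyOps with BLIS matrix multiplication the global default.
set_current_ops(get_ops("numpy", use_blis=True))
print(get_current_ops().name)  # "numpy"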

Backend | CPU | GPU | TPU | Description
NumpyOps | yes | no | no | Execute via numpy, blis (optional) and custom Cython.
CupyOps | no | yes | no | Execute via cupy and custom CUDA.

Ops class

The Ops class is typically not used directly but via NumpyOps or CupyOps, which are subclasses of Ops and implement a more efficient subset of the methods. You also have access to the ops via the Model.ops attribute. The documented methods below list which backends provide optimized and more efficient versions, and which use the default implementation. Thinc also provides various helper functions for getting and setting different backends.

Example:
from thinc.api import Linear, get_ops, use_ops

model = Linear(4, 2)
X = model.ops.alloc2f(10, 2)
blis_ops = get_ops("numpy", use_blis=True)
use_ops(blis_ops)

Attributes

Name | Type | Description
name | str | Class attribute: Backend name, "numpy" or "cupy".
xp | Xp | Class attribute: numpy or cupy.
device_type | str | The device type to use, if available for the given backend: "cpu", "gpu" or "tpu".
device_id | int | The device ID to use, if available for the given backend.

Ops.__init__ method

Argument | Type | Description
device_type | str | The device type to use, if available for the given backend: "cpu", "gpu" or "tpu".
device_id | int | The device ID to use, if available for the given backend.
keyword-only
use_blis | bool | NumpyOps: Use blis for single-threaded matrix multiplication.

Ops.minibatch method

  • default:
  • numpy: default
  • cupy: default

Iterate slices from a sequence, optionally shuffled. Slices may be either views or copies of the underlying data. Supports the batchable data types Pairs, Ragged and Padded, as well as arrays, lists and tuples. The size argument may be either an integer, or a sequence of integers. If a sequence, a new size is drawn before every output. If shuffle is True, shuffled batches are produced by first generating an index array, shuffling it, and then using it to slice into the sequence. An internal queue of buffer items is accumulated before each output is produced. Buffering is useful for some devices, to allow the network to run asynchronously without blocking on every batch.

The method returns a SizedGenerator that exposes a __len__ and is rebatched and reshuffled every time it’s executed, allowing you to move the batching outside of the training loop.

Example:
batches = model.ops.minibatch(128, train_X, shuffle=True)
Argument | Type | Description
size | Union[int, Generator] | The batch size(s).
sequence | Batchable | The sequence to batch.
keyword-only
shuffle | bool | Whether to shuffle the items.
buffer | int | Number of items to accumulate before each output. Defaults to 1.
RETURNS | SizedGenerator | The batched items.

Ops.multibatch method

  • default:
  • numpy: default
  • cupy: default

Minibatch one or more sequences of data, and return lists with one batch per sequence. Otherwise identical to Ops.minibatch.

Example:
batches = model.ops.multibatch(128, train_X, train_Y, shuffle=True)
Argument | Type | Description
size | Union[int, Generator] | The batch size(s).
sequence | Batchable | The sequence to batch.
*other | Batchable | The other sequences to batch.
keyword-only
shuffle | bool | Whether to shuffle the items.
buffer | int | Number of items to accumulate before each output. Defaults to 1.
RETURNS | SizedGenerator | The batched items.
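As a sketch of typical usage, the snippet below pairs inputs and labels in a small training loop; the toy data, the Linear model and the Adam optimizer are only placeholders. Because the SizedGenerator re-batches and reshuffles on every pass, it can be created once and iterated again each epoch:

import numpy
from thinc.api import Adam, Linear

train_X = numpy.random.rand(1000, 4).astype("float32")  # hypothetical inputs
train_Y = numpy.random.rand(1000, 2).astype("float32")  # hypothetical targets

model = Linear(2, 4)
model.initialize(X=train_X, Y=train_Y)
optimizer = Adam(0.001)

batches = model.ops.multibatch(128, train_X, train_Y, shuffle=True)
for X, Y in batches:
    Yh, backprop = model.begin_update(X)
    backprop(Yh - Y)
    model.finish_update(optimizer)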

Ops.seq2col method

  • default: (nW=1 only)
  • numpy:
  • cupy:

Given an (M, N) sequence of vectors, return an (M, N*(nW*2+1)) sequence. The new sequence is constructed by concatenating nW preceding and succeeding vectors onto each column in the sequence, to extract a window of features.

ArgumentTypeDescription
seqFloats2dThe original sequence.
nWintThe window size.
keyword-only
lengthsOptional[Ints1d]Sequence lengths, introduces padding around sequences.
RETURNSFloats2dThe created sequence containing preceding and succeeding vectors.
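A small sketch of the expected shapes, using the NumpyOps backend:

from thinc.api import get_ops

ops = get_ops("numpy")
seq = ops.alloc2f(5, 4)          # (M, N) = (5, 4)
cols = ops.seq2col(seq, 1)       # nW=1: one vector of context on each side
assert cols.shape == (5, 4 * 3)  # (M, N * (nW * 2 + 1))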

Ops.backprop_seq2col method

  • default: (nW=1 only)
  • numpy:
  • cupy:

The reverse/backward operation of the seq2col function: calculate the gradient of the original (M, N) sequence, as a function of the gradient of the output (M, N*(nW*2+1)) sequence.

ArgumentTypeDescription
dYFloats2dGradient of the output sequence.
nWintThe window size.
keyword-only
lengthsOptional[Ints1d]Sequence lengths, introduces padding around sequences.
RETURNSFloats2dGradient of the original sequence.

Ops.gemm method

  • default:
  • numpy:
  • cupy:

Perform General Matrix Multiplication (GeMM) and optionally store the result in the specified output variable.

ArgumentTypeDescription
xFloats2dFirst array.
yFloats2dSecond array.
outOptional[Floats2d]Variable to store the result of the matrix multiplication in.
trans1boolWhether or not to transpose array x.
trans2boolWhether or not to transpose array y.
RETURNSFloats2dThe result of the matrix multiplication.
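For instance, a minimal sketch multiplying a (10, 3) array by the transpose of a (2, 3) array:

from thinc.api import get_ops

ops = get_ops("numpy")
x = ops.alloc2f(10, 3, zeros=False)
y = ops.alloc2f(2, 3, zeros=False)
out = ops.gemm(x, y, trans2=True)  # equivalent to x @ y.T
assert out.shape == (10, 2)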

Ops.affine method

  • default:
  • numpy: default
  • cupy: default

Apply a weights layer and a bias to some inputs, i.e. Y = X @ W.T + b.

ArgumentTypeDescription
XFloats2dThe inputs.
WFloats2dThe weights.
bFloats1dThe bias vector.
RETURNSFloats2dThe output.
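A short sketch of the expected shapes; the computation is the same as a gemm with trans2=True followed by adding the bias:

from thinc.api import get_ops

ops = get_ops("numpy")
X = ops.alloc2f(8, 4, zeros=False)
W = ops.alloc2f(3, 4, zeros=False)  # (nO, nI)
b = ops.alloc1f(3)
Y = ops.affine(X, W, b)             # Y = X @ W.T + b
assert Y.shape == (8, 3)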

Ops.flatten method

  • default:
  • numpy: default
  • cupy: default

Flatten a list of arrays into one large array.

ArgumentTypeDescription
XSequence[ArrayXd]The original list of arrays.
dtypeOptional[DTypes]The data type to cast the resulting array in.
padintThe number of zeros to add as padding to X (default 0).
ndim_if_emptyintThe dimension of the output result if X is None or empty.
RETURNSArrayXdOne large array storing all original information.

Ops.unflatten method

  • default:
  • numpy: default
  • cupy: default

The reverse/backward operation of the flatten function: unflatten a large array into a list of arrays according to the given lengths.

ArgumentTypeDescription
XArrayXdThe flattened array.
lengthsInts1dThe lengths of the original arrays before they were flattened.
padintThe padding that was applied during the flatten step (default 0).
RETURNSList[ArrayXd]A list of arrays storing the same information as the flattened array.
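A roundtrip sketch of flatten and unflatten:

from thinc.api import get_ops

ops = get_ops("numpy")
seqs = [ops.alloc2f(2, 4), ops.alloc2f(3, 4)]
lengths = ops.asarray1i([len(seq) for seq in seqs])
flat = ops.flatten(seqs)               # shape (5, 4)
unflat = ops.unflatten(flat, lengths)  # back to lengths 2 and 3
assert [len(seq) for seq in unflat] == [2, 3]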

Ops.pad method

  • default:
  • numpy: default
  • cupy: default

Perform padding on a list of arrays so that they each have the same length, by taking the maximum dimension across each axis. This only works on non-empty sequences with the same ndim and dtype.

ArgumentTypeDescription
seqsList[Array2d]The sequences to pad.
round_tointRound the length to nearest bucket (helps on GPU, to make similar array sizes). Defaults to 1.
RETURNSArray3dThe padded sequences, stored in one array.

Ops.unpad method

  • default:
  • numpy: default
  • cupy: default

The reverse/backward operation of the pad function: transform an array back into a list of arrays, each with their original length.

ArgumentTypeDescription
paddedArrayXdThe padded sequences, stored in one array.
lengthsList[int]The original lengths of the unpadded sequences.
RETURNSList[ArrayXd]The unpadded sequences.
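A roundtrip sketch of pad and unpad:

from thinc.api import get_ops

ops = get_ops("numpy")
seqs = [ops.alloc2f(2, 4), ops.alloc2f(5, 4)]
padded = ops.pad(seqs)               # shape (2, 5, 4)
unpadded = ops.unpad(padded, [2, 5])
assert [len(seq) for seq in unpadded] == [2, 5]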

Ops.list2padded method

  • default:
  • numpy: default
  • cupy: default

Pack a sequence of two-dimensional arrays into a Padded datatype.

ArgumentTypeDescription
seqsList[Array2d]The sequences to pack.
RETURNSPaddedThe packed arrays.

Ops.padded2list method

  • default:
  • numpy: default
  • cupy: default

Unpack a Padded datatype to a list of two-dimensional arrays.

ArgumentTypeDescription
paddedPaddedThe object to unpack.
RETURNSList[Array2d]The unpacked sequences.

Ops.get_dropout_mask method

  • default:
  • numpy: default
  • cupy: default

Create a random mask for applying dropout, in which a certain proportion of the values (defined by drop) are zeros. The neurons at those positions are deactivated during training, resulting in a more robust network and less overfitting.

ArgumentTypeDescription
shapeShapeThe input shape.
dropOptional[float]The dropout rate.
RETURNSFloatsA mask specifying a 0 where a neuron should be deactivated.
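A sketch of applying a dropout mask to some activations; the non-zero mask entries are scaled up so that the expected value of the activations is preserved (inverted dropout):

from thinc.api import get_ops

ops = get_ops("numpy")
Y = ops.alloc2f(10, 128, zeros=False)
mask = ops.get_dropout_mask(Y.shape, 0.2)  # ~20% of positions are zero
Y_dropped = Y * mask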

Ops.alloc method

  • default:
  • numpy:
  • cupy: default

Allocate an array of a certain shape. If possible, you should always use the type-specific methods listed below, as they make the code more readable and allow more sophisticated static type checking of the inputs and outputs.

ArgumentTypeDescription
shapeShapeThe shape.
keyword-only
dtypeDTypesThe data type (default: float32).
zerosboolFill the array with zeros (default: True).
RETURNSArrayXdAn array of the correct shape and data type.

Ops.cblas method

  • default:
  • numpy:
  • cupy:

Get a table of C BLAS functions usable in Cython cdef nogil functions. This method does not take any arguments.

Ops.to_numpy method

  • default:
  • numpy: default
  • cupy:

Convert the array to a numpy array.

ArgumentTypeDescription
dataArrayXdThe array.
keyword-only
byte_orderOptional[str]The new byte order, None preserves the current byte order (default: None).
RETURNSnumpy.ndarrayA numpy array with the specified byte order.

Type-specific methods

  • Floats: Ops.alloc_f, Ops.alloc1f, Ops.alloc2f, Ops.alloc3f, Ops.alloc4f
  • Ints: Ops.alloc_i, Ops.alloc1i, Ops.alloc2i, Ops.alloc3i, Ops.alloc4i

Shortcuts to allocate an array of a certain shape and data type (f refers to float32 and i to int32). For instance, Ops.alloc2f will allocate a two-dimensional array of floats.

Example:
X = model.ops.alloc2f(10, 2)  # Floats2d
Y = model.ops.alloc1i(4)  # Ints1d
ArgumentTypeDescription
*shapeintThe shape, one positional argument per dimension.
keyword-only
dtypeDTypesInt / DTypesFloatThe data type (float type for float methods and int type for int methods).
zerosboolFill the array with zeros (default: True).
RETURNSArrayXdAn array of the correct shape and data type.

Ops.reshape method

  • default:
  • numpy: default
  • cupy: default

Reshape an array and return an array containing the same data with the given shape. If possible, you should always use the type-specific methods listed below, as they make the code more readable and allow more sophisticated static type checking of the inputs and outputs.

ArgumentTypeDescription
arrayArrayXdThe array to reshape.
shapeShapeThe shape.
RETURNSArrayXdThe reshaped array.

Type-specific methods

  • Floats: Ops.reshape_f, Ops.reshape1f, Ops.reshape2f, Ops.reshape3f, Ops.reshape4f
  • Ints: Ops.reshape_i, Ops.reshape1i, Ops.reshape2i, Ops.reshape3i, Ops.reshape4i

Shortcuts to reshape an array of a certain shape and data type (f refers to float32 and i to int32). For instance, reshape2f can be used to reshape an array of floats to a 2d-array of floats.

Example:
X = model.ops.reshape2f(X, 10, 2)  # Floats2d
Y = model.ops.reshape1i(Y, 4)  # Ints1d
ArgumentTypeDescription
arrayArrayXdThe array to reshape (of the same data type).
*shapeintThe shape, one positional argument per dimension.
RETURNSArrayXdThe reshaped array (of the same data type as the input array).

Ops.asarray method

  • default:
  • numpy:
  • cupy:

Ensure a given array is of the correct type, e.g. numpy.ndarray for NumpyOps or cupy.ndarray for CupyOps. If possible, you should always use the type-specific methods listed below, as they make the code more readable and allow more sophisticated static type checking of the inputs and outputs.

ArgumentTypeDescription
dataUnion[ArrayXd, Sequence[ArrayXd], Sequence[int]]The original array.
keyword-only
dtypeOptional[DTypes]The data type
RETURNSArrayXdThe array transformed to the correct type.

Type-specific methods

  • Floats: Ops.asarray_f, Ops.asarray1f, Ops.asarray2f, Ops.asarray3f, Ops.asarray4f
  • Ints: Ops.asarray_i, Ops.asarray1i, Ops.asarray2i, Ops.asarray3i, Ops.asarray4i

Shortcuts for specific dimensions and data types (f refers to float32 and i to int32). For instance, Ops.asarray2f will return a two-dimensional array of floats.

Example:
X = model.ops.asarray2f(X, 10, 2)  # Floats2d
Y = model.ops.asarray1i(Y, 4)  # Ints1d
ArgumentTypeDescription
*shapeintThe shape, one positional argument per dimension.
keyword-only
dtypeDTypesInt / DTypesFloatThe data type (float type for float methods and int type for int methods).
RETURNSArrayXdAn array of the correct shape and data type, filled with zeros.

Ops.as_contig method

  • default:
  • numpy: default
  • cupy: default

Allow the backend to make a contiguous copy of an array. Implementations of Ops do not have to make a copy or make it contiguous if that would not improve efficiency for the execution engine.

ArgumentTypeDescription
dataArrayXdThe array.
keyword-only
dtypeOptional[DTypes]The data type
RETURNSArrayXdAn array with the same contents as the input.

Ops.unzip method

  • default:
  • numpy: default
  • cupy: default

Unzip a tuple of two arrays, transform them with asarray and return them as two separate arrays.

ArgumentTypeDescription
dataTuple[ArrayXd, ArrayXd]The tuple of two arrays.
RETURNSTuple[ArrayXd, ArrayXd]The two arrays, transformed with asarray.

Ops.sigmoid method

  • default:
  • numpy: default
  • cupy: default

Calculate the sigmoid function.

ArgumentTypeDescription
XFloatsXdThe input values.
keyword-only
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe output values, i.e. S(X).

Ops.dsigmoid method

  • default:
  • numpy: default
  • cupy: default

Calculate the derivative of the sigmoid function.

ArgumentTypeDescription
YFloatsXdThe input values.
keyword-only
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe output values, i.e. dS(Y).

Ops.dtanh method

  • default:
  • numpy: default
  • cupy: default

Calculate the derivative of the tanh function.

ArgumentTypeDescription
YFloatsXdThe input values.
keyword-only
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe output values, i.e. dtanh(Y).

Ops.softmax method

  • default:
  • numpy: default
  • cupy: default

Calculate the softmax function. The resulting array sums to 1 along the axis that is normalized over.

ArgumentTypeDescription
xFloatsXdThe input values.
keyword-only
inplaceboolIf True, the array is modified in place.
axisintThe dimension to normalize over.
temperaturefloatThe value to divide the unnormalized probabilities by.
RETURNSFloatsXdThe normalized output values.
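A small sketch of a row-wise softmax over class scores:

from thinc.api import get_ops

ops = get_ops("numpy")
scores = ops.asarray2f([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
probs = ops.softmax(scores, axis=-1)
assert ops.xp.allclose(probs.sum(axis=-1), 1.0)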

Ops.backprop_softmax method

  • default:
  • numpy: default
  • cupy: default
ArgumentTypeDescription
YFloatsXdOutput array.
dYFloatsXdGradients of the output array.
keyword-only
axisintThe dimension that was normalized over.
temperaturefloatThe value to divide the unnormalized probabilities by.
RETURNSFloatsXdThe gradients of the input array.

Ops.softmax_sequences method

  • default:
  • numpy: default
  • cupy: default
ArgumentTypeDescription
XsFloats2dA 2d array of input sequences.
lengthsInts1dThe lengths of the input sequences.
keyword-only
inplaceboolIf True, the array is modified in place.
axisintThe dimension to normalize over.
RETURNSFloats2dThe normalized output values.

Ops.backprop_softmax_sequences method

  • default:
  • numpy: default
  • cupy: default

The reverse/backward operation of the softmax function.

ArgumentTypeDescription
dYFloats2dGradients of the output array.
YFloats2dOutput array.
lengthsInts1dThe lengths of the input sequences.
RETURNSFloats2dThe gradients of the input sequences.

Ops.recurrent_lstm method

  • default:
  • numpy: default
  • cupy: default

Encode a padded batch of inputs into a padded batch of outputs using an LSTM.

ArgumentTypeDescription
WFloats2dThe weights, shaped (nO * 4, nO + nI).
bFloats1dThe bias vector, shaped (nO * 4,).
h_initFloats1dInitial value for the previous hidden vector.
c_initFloats1dInitial value for the previous cell state.
inputsFloats3dA batch of inputs, shaped (nL, nB, nI), where nL is the sequence length and nB is the batch size.
is_trainboolWhether the model is running in a training context.
RETURNSTuple[Floats3d, Tuple[Floats3d, Floats3d, Floats3d]]A tuple consisting of the outputs and the intermediate activations required for the backward pass. The outputs are shaped (nL, nB, nO).
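A hedged sketch of calling recurrent_lstm directly with zero-initialized parameters, using made-up sizes for nL, nB, nI and nO:

from thinc.api import get_ops

ops = get_ops("numpy")
nL, nB, nI, nO = 5, 2, 4, 3
W = ops.alloc2f(nO * 4, nO + nI)
b = ops.alloc1f(nO * 4)
h_init = ops.alloc1f(nO)
c_init = ops.alloc1f(nO)
inputs = ops.alloc3f(nL, nB, nI)
Y, fwd_state = ops.recurrent_lstm(W, b, h_init, c_init, inputs, is_train=True)
# Y has shape (nL, nB, nO); fwd_state holds the intermediate activations
# needed by Ops.backprop_recurrent_lstm.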

Ops.backprop_recurrent_lstm method

  • default:
  • numpy: default
  • cupy: default

Compute the gradients for the recurrent_lstm operation via backpropagation.

ArgumentTypeDescription
dYFloats3dThe gradient w.r.t. the outputs.
fwd_stateTuple[Floats3d, Floats3d, Floats3d]The tuple of gates, cells and inputs, returned by the forward pass.
paramsTuple[Floats2d, Floats1d]A tuple of the weights and biases.
RETURNSTuple[Floats3d, Tuple[Floats2d, Floats1d, Floats1d, Floats1d]]The gradients for the inputs and parameters (the weights, biases, initial hiddens and initial cells).

Ops.maxout method

  • default:
  • numpy:
  • cupy:
ArgumentTypeDescription
XFloats3dThe inputs.
RETURNSTuple[Floats2d, Ints2d]The outputs and an array indicating which elements in the final axis were used.

Ops.backprop_maxout method

  • default:
  • numpy:
  • cupy:
ArgumentTypeDescription
dYFloats2dGradients of the output array.
whichInts2dThe positions selected in the forward pass.
PintThe size of the final dimension.
RETURNSFloats3dThe gradient of the inputs.

Ops.relu method

  • default:
  • numpy:
  • cupy:
ArgumentTypeDescription
XFloats2dThe inputs.
keyword-only
inplaceboolIf True, the array is modified in place.
RETURNSFloats2dThe outputs.

Ops.backprop_relu method

  • default:
  • numpy:
  • cupy:
ArgumentTypeDescription
dYFloats2dGradients of the output array.
YFloats2dThe output from the forward pass.
keyword-only
inplaceboolIf True, the array is modified in place.
RETURNSFloats2dThe gradient of the input.
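The forward and backward methods are typically paired inside a layer's forward function. A minimal sketch of a hypothetical ReLU layer built directly on these ops methods:

from thinc.api import Model

def relu_layer() -> Model:
    def forward(model: Model, X, is_train: bool):
        Y = model.ops.relu(X)

        def backprop(dY):
            return model.ops.backprop_relu(dY, Y)

        return Y, backprop

    return Model("relu_layer", forward)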

Ops.mish method

  • default:
  • numpy:
  • cupy:

Compute the Mish activation (Misra, 2019).

ArgumentTypeDescription
XFloatsXdThe inputs.
thresholdfloatMaximum value at which to apply the activation.
inplaceboolApply Mish to X in-place.
RETURNSFloatsXdThe outputs.

Ops.backprop_mish method

  • default:
  • numpy:
  • cupy:

Backpropagate the Mish activation (Misra, 2019).

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
thresholdfloatThreshold from the forward pass.
inplaceboolApply Mish backprop to dY in-place.
RETURNSFloatsXdThe gradient of the input.

Ops.swish method

  • default:
  • numpy:
  • cupy:

Swish (Ramachandran et al., 2017) is a self-gating non-monotonic activation function similar to the GELU activation: whereas GELU uses the CDF of the Gaussian distribution Φ for self-gating x * Φ(x), Swish uses the logistic CDF x * σ(x). Sometimes referred to as “SiLU” for “Sigmoid Linear Unit”.

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_swish method

  • default:
  • numpy:
  • cupy:

Backpropagate the Swish activation (Ramachandran et al., 2017).

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
YFloatsXdThe outputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.dish methodNew: v8.1.1

  • default:
  • numpy:
  • cupy:

Dish or “Daniël’s Swish-like activation” is an activation function with a non-monotonic shape similar to GELU, Swish and Mish. However, Dish does not rely on elementary functions like exp or erf, making it much faster to compute in most cases.

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_dish methodNew: v8.1.1

  • default:
  • numpy:
  • cupy:

Backpropagate the Dish activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.gelu method

  • default:
  • numpy:
  • cupy:

GELU or “Gaussian Error Linear Unit” (Hendrycks and Gimpel, 2016) is a self-gating non-monotonic activation function similar to the Swish activation: whereas GELU uses the CDF of the Gaussian distribution Φ for self-gating x * Φ(x), the Swish activation uses the logistic CDF σ and computes x * σ(x). Various approximations exist, but Thinc implements the exact GELU. The use of GELU is popular within transformer feed-forward blocks.

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_gelu method

  • default:
  • numpy:
  • cupy:

Backpropagate the GELU activation (Hendrycks and Gimpel, 2016).

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.relu_k method

  • default:
  • numpy:
  • cupy:

ReLU activation function with the maximum value clipped at k. A common choice is k=6, introduced for convolutional deep belief networks (Krizhevsky, 2010). The resulting function relu6 is commonly used in low-precision scenarios.

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
kfloatMaximum value (default: 6.0).
RETURNSFloatsXdThe outputs.

Ops.backprop_relu_k method

  • default:
  • numpy:
  • cupy:

Backpropagate the ReLU-k activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.hard_sigmoid method

  • default:
  • numpy:
  • cupy:

The hard sigmoid activation function is a fast linear approximation of the sigmoid activation, defined as max(0, min(1, x * 0.2 + 0.5)).

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_hard_sigmoid method

  • default:
  • numpy:
  • cupy:

Backpropagate the hard sigmoid activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.hard_tanh method

  • default:
  • numpy:
  • cupy:

The hard tanh activation function is a fast linear approximation of tanh, defined as max(-1, min(1, x)).

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_hard_tanh method

  • default:
  • numpy:
  • cupy:

Backpropagate the hard tanh activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.clipped_linear method

  • default:
  • numpy:
  • cupy:

Flexible clipped linear activation function of the form max(min_value, min(max_value, x * slope + offset)). It is used to implement the relu_k, hard_sigmoid, and hard_tanh methods.

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
slopefloatThe slope of the linear function: input * slope.
offsetfloatThe offset or intercept of the linear function: input * slope + offset.
min_valfloatMinimum value to clip to.
max_valfloatMaximum value to clip to.
RETURNSFloatsXdThe outputs.
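As a sketch (assuming the keyword names listed above), relu_k and hard_sigmoid can be expressed through clipped_linear:

from thinc.api import get_ops

ops = get_ops("numpy")
X = ops.asarray2f([[-2.0, 0.5, 8.0]])
relu6 = ops.clipped_linear(X, slope=1.0, offset=0.0, min_val=0.0, max_val=6.0)
# same result as ops.relu_k(X)
hard_sig = ops.clipped_linear(X, slope=0.2, offset=0.5, min_val=0.0, max_val=1.0)
# same result as ops.hard_sigmoid(X)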

Ops.backprop_clipped_linear method

  • default:
  • numpy:
  • cupy:

Backpropagate the clipped linear activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
slopefloatThe slope of the linear function: input * slope.
offsetfloatThe offset or intercept of the linear function: input * slope + offset.
min_valfloatMinimum value to clip to.
max_valfloatMaximum value to clip to.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.hard_swish method

  • default:
  • numpy:
  • cupy:

The hard Swish activation function is a fast linear approximation of Swish: x * hard_sigmoid(x).

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_hard_swish method

  • default:
  • numpy:
  • cupy:

Backpropagate the hard Swish activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.hard_swish_mobilenet method

  • default:
  • numpy:
  • cupy:

A variant of the fast hard Swish activation function used in MobileNetV3 (Howard et al., 2019), defined as x * (relu6(x + 3) / 6).

ArgumentTypeDescription
XFloatsXdThe inputs.
inplaceboolIf True, the array is modified in place.
RETURNSFloatsXdThe outputs.

Ops.backprop_hard_swish_mobilenet method

  • default:
  • numpy:
  • cupy:

Backpropagate the hard Swish MobileNet activation.

ArgumentTypeDescription
dYFloatsXdGradients of the output array.
XFloatsXdThe inputs to the forward pass.
inplaceboolIf True, the dY array is modified in place.
RETURNSFloatsXdThe gradient of the input.

Ops.reduce_first method

  • default:
  • numpy: default
  • cupy: default

Perform sequence-wise first pooling for data in the ragged format.

  • Zero-length sequences are not allowed. A ValueError is raised if any element in lengths is zero.
  • Batch and hidden dimensions can have a size of zero. In these cases the corresponding dimensions in the output also have a size of zero.
ArgumentTypeDescription
XFloats2dThe concatenated sequences.
lengthsInts1dThe sequence lengths.
RETURNSTuple[Floats2d,Ints1d]The first vector of each sequence and the sequence start/end indices.

Ops.backprop_reduce_first method

  • default:
  • numpy: default
  • cupy: default

Backpropagate the reduce_first operation.

ArgumentTypeDescription
d_firstsFloats2dThe gradient of the outputs.
starts_endsInts1dThe sequence start/end indices.
RETURNSFloats2dThe gradient of the concatenated sequences.

Ops.reduce_last method

  • default:
  • numpy: default
  • cupy: default

Perform sequence-wise last pooling for data in the ragged format.

  • Zero-length sequences are not allowed. A ValueError is raised if any element in lengths is zero.
  • Batch and hidden dimensions can have a size of zero. In these cases the corresponding dimensions in the output also have a size of zero.
ArgumentTypeDescription
XFloats2dThe concatenated sequences.
lengthsInts1dThe sequence lengths.
RETURNSTuple[Floats2d,Ints1d]The last vector of each sequence and the indices of the last sequence elements.

Ops.backprop_reduce_last method

  • default:
  • numpy: default
  • cupy: default

Backpropagate the reduce_last operation.

ArgumentTypeDescription
d_lastsFloats2dThe gradient of the outputs.
lastsInts1dIndices of the last sequence elements.
RETURNSFloats2dThe gradient of the concatenated sequences.

Ops.reduce_sum method

  • default:
  • numpy:
  • cupy:

Perform sequence-wise summation for data in the ragged format.

  • Zero-length sequences are reduced to all-zero vectors.
  • Batch and hidden dimensions can have a size of zero. In these cases the corresponding dimensions in the output also have a size of zero.
ArgumentTypeDescription
XFloats2dThe concatenated sequences.
lengthsInts1dThe sequence lengths.
RETURNSFloats2dThe sequence-wise summations.
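A sketch of reducing three ragged sequences with lengths 2, 3 and 1; reduce_mean and reduce_max follow the same pattern:

from thinc.api import get_ops

ops = get_ops("numpy")
X = ops.alloc2f(6, 4, zeros=False)  # concatenated sequences, 2 + 3 + 1 rows
lengths = ops.asarray1i([2, 3, 1])
sums = ops.reduce_sum(X, lengths)
assert sums.shape == (3, 4)         # one row per sequence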

Ops.backprop_reduce_sum method

  • default:
  • numpy:
  • cupy:

Backpropagate the reduce_sum operation.

ArgumentTypeDescription
d_sumsFloats2dThe gradient of the outputs.
lengthsInts1dThe sequence lengths.
RETURNSFloats2dThe gradient of the concatenated sequences.

Ops.reduce_mean method

  • default:
  • numpy:
  • cupy:

Perform sequence-wise averaging for data in the ragged format.

  • Zero-length sequences are reduced to all-zero vectors.
  • Batch and hidden dimensions can have a size of zero. In these cases the corresponding dimensions in the output also have a size of zero.
ArgumentTypeDescription
XFloats2dThe concatenated sequences.
lengthsInts1dThe sequence lengths.
RETURNSFloats2dThe sequence-wise averages.

Ops.backprop_reduce_mean method

  • default:
  • numpy:
  • cupy:

Backpropagate the reduce_mean operation.

ArgumentTypeDescription
d_meansFloats2dThe gradient of the outputs.
lengthsInts1dThe sequence lengths.
RETURNSFloats2dThe gradient of the concatenated sequences.

Ops.reduce_max method

  • default:
  • numpy:
  • cupy:

Perform sequence-wise max pooling for data in the ragged format.

  • Zero-length sequences are not allowed. A ValueError is raised if any element in lengths is zero.
  • Batch and hidden dimensions can have a size of zero. In these cases the corresponding dimensions in the output also have a size of zero.
ArgumentTypeDescription
XFloats2dThe concatenated sequences.
lengthsInts1dThe sequence lengths.
RETURNSTuple[Floats2d, Ints2d]The sequence-wise maximums.

Ops.backprop_reduce_max method

  • default:
  • numpy:
  • cupy:

Backpropagate the reduce_max operation.

ArgumentTypeDescription
d_maxesFloats2dThe gradient of the outputs.
whichInts2dThe indices selected.
lengthsInts1dThe sequence lengths.
RETURNSFloats2dThe gradient of the concatenated sequences.

Ops.hash method

  • default:
  • numpy:
  • cupy:

Hash a sequence of 64-bit keys into a table with four 32-bit keys, using murmurhash3.

ArgumentTypeDescription
idsInts1dThe keys, 64-bit unsigned integers.
seedintThe hashing seed.
RETURNSInts2dThe hashes.

Ops.ngrams method

  • default:
  • numpy:
  • cupy: default

Create hashed ngram features.

ArgumentTypeDescription
nintThe window to calculate each feature over.
keysInts1dThe input sequence.
RETURNSInts1dThe hashed ngrams.

Ops.gather_add methodNew: v8.1

  • default:
  • numpy:
  • cupy:

Gather rows from table with shape (T, O) using array indices with shape (B, K), then sum the resulting array with shape (B, K, O) over the K axis.

ArgumentTypeDescription
tableFloats2dThe array to increment.
indicesInts2dThe indices to use.
RETURNSFloats2dThe summed rows.
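A sketch showing that gather_add is roughly equivalent to fancy indexing followed by a sum over the K axis:

from thinc.api import get_ops

ops = get_ops("numpy")
table = ops.asarray2f(ops.xp.random.rand(10, 8))  # (T, O)
indices = ops.asarray2i([[0, 1], [2, 3]])         # (B, K)
rows = ops.gather_add(table, indices)             # (B, O)
assert ops.xp.allclose(rows, table[indices].sum(axis=1))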

Ops.scatter_add method

  • default:
  • numpy:
  • cupy:

Increment entries in the array table using the given indices and values.

ArgumentTypeDescription
tableFloatsXdThe array to increment.
indicesIntsXdThe indices to use.
valuesFloatsXdThe inputs.
RETURNSFloatsXdThe incremented array.

Utilities

get_ops function

Get a backend object using a string name.

Example:
from thinc.api import get_ops

numpy_ops = get_ops("numpy")

Argument | Type | Description
ops | str | "numpy" or "cupy".
**kwargs | | Optional arguments passed to Ops.__init__.
RETURNS | Ops | The backend object.

use_ops contextmanager

Change the backend to execute with for the scope of the block.

Example:
from thinc.api import use_ops, get_current_ops

with use_ops("cupy"):
    current_ops = get_current_ops()
    assert current_ops.name == "cupy"

Argument | Type | Description
ops | str | "numpy" or "cupy".
**kwargs | | Optional arguments passed to Ops.__init__.

get_current_ops function

Get the current backend object.

Argument | Type | Description
RETURNS | Ops | The current backend object.

set_current_ops function

Set the current backend object.

Argument | Type | Description
ops | Ops | The backend object.
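A minimal sketch of switching the global backend to CuPy; it assumes a working GPU and CuPy installation, otherwise keep the default NumpyOps:

from thinc.api import CupyOps, get_current_ops, set_current_ops

set_current_ops(CupyOps())  # requires CuPy and a GPU
assert get_current_ops().name == "cupy"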

set_gpu_allocator function

Set the CuPy GPU memory allocator.

Argument | Type | Description
allocator | str | Either "pytorch" or "tensorflow".

Example:
from thinc.api import set_gpu_allocator

set_gpu_allocator("pytorch")