pydpf.resampling.OptimalTransportResampler

class pydpf.resampling.OptimalTransportResampler(regularisation: float, decay_rate: float, min_update_size: float, max_iterations: int, transport_gradient_clip: float)

Bases: Module

Optimal transport based resampling

Optimal transport resampling produces a deterministic, differentiable transport map from the proposal distribution to the posterior. This is achieved by solving an entropy-regularised Kantorovich optimal transport problem between the two empirical distributions. The particles are then transformed by the resulting map to obtain a new, unweighted approximation of the posterior.
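As an illustration of the idea, here is a minimal PyTorch sketch of entropy-regularised OT resampling at a single, fixed regularisation strength. It is not the pydpf implementation: the function name `entropic_ot_resample`, the squared Euclidean cost, the fixed iteration count, and the absence of rescaling, annealing and gradient clipping are all assumptions made for brevity. A log-domain Sinkhorn loop finds a plan T whose rows sum to 1/N (the uniform target) and whose columns sum to the proposal weights; each new particle is the barycentric projection N Σ_j T_ij x_j.

```python
import math

import torch


def entropic_ot_resample(log_weight: torch.Tensor, x: torch.Tensor,
                         epsilon: float = 0.5, n_iter: int = 100):
    """Sketch of entropy-regularised OT resampling at a fixed epsilon.

    log_weight: (B, N) normalised log-weights of the proposal particles.
    x:          (B, N, D) particle positions.
    Returns new (B, N, D) positions and (B, N) uniform log-weights.
    """
    B, N, _ = x.shape
    log_uniform = torch.full((B, N), -math.log(N), dtype=x.dtype, device=x.device)
    cost = torch.cdist(x, x) ** 2  # assumed squared Euclidean cost, (B, N, N)

    # Log-domain Sinkhorn. Rows (index i) carry the uniform target marginal,
    # columns (index j) carry the weighted source marginal.
    f = torch.zeros_like(log_uniform)
    g = torch.zeros_like(log_weight)
    for _ in range(n_iter):
        f = -epsilon * torch.logsumexp(
            log_weight[:, None, :] + (g[:, None, :] - cost) / epsilon, dim=2)
        g = -epsilon * torch.logsumexp(
            log_uniform[:, :, None] + (f[:, :, None] - cost) / epsilon, dim=1)
    plan = torch.exp(log_uniform[:, :, None] + log_weight[:, None, :]
                     + (f[:, :, None] + g[:, None, :] - cost) / epsilon)

    # Barycentric projection: new particle i is N * sum_j T_ij x_j.
    new_x = N * plan @ x
    return new_x, log_uniform
```

Because the last Sinkhorn update enforces the column marginals exactly, the unweighted mean of the new particles equals the weighted mean of the old ones, whatever the regularisation strength.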

Parameters:
regularisation: float

The minimum strength of the entropy regularisation. In our implementation, the regularisation strength is chosen automatically per sample and exponentially decayed to this value.

decay_rate: float

The factor by which to decrease the entropy regularisation per Sinkhorn iteration.

min_update_size: float

The size of the update to the transport potentials below which iteration stops.

max_iterations: int

The maximum number of iterations of the Sinkhorn loop, after which iteration stops regardless of convergence.

transport_gradient_clip: float

The maximum per-element gradient of the transport matrix that is passed back. Gradient elements of larger magnitude are clipped to this value.
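The sketch below shows one way the four loop-control parameters could interact in a log-domain Sinkhorn solver: the regularisation starts at some `epsilon_0`, is multiplied by `decay_rate` after each iteration but never drops below `regularisation`, and the loop ends when the largest change in either potential falls below `min_update_size` or after `max_iterations` iterations. The function `annealed_sinkhorn` and its argument layout are hypothetical, not the pydpf API.

```python
import torch


def annealed_sinkhorn(log_a: torch.Tensor, log_b: torch.Tensor, cost: torch.Tensor,
                      epsilon_0: float, regularisation: float, decay_rate: float,
                      min_update_size: float, max_iterations: int):
    """Log-domain Sinkhorn with an exponentially decaying regularisation.

    log_a, log_b: (B, N) log row / column marginals; cost: (B, N, N).
    Returns potentials f, g, the epsilon they were computed with, and the
    number of iterations run.
    """
    f = torch.zeros_like(log_a)
    g = torch.zeros_like(log_b)
    epsilon = max(epsilon_0, regularisation)
    n_iter = 0
    while n_iter < max_iterations:
        if n_iter > 0:
            # Decay the regularisation, never going below `regularisation`.
            epsilon = max(epsilon * decay_rate, regularisation)
        f_new = -epsilon * torch.logsumexp(
            log_b[:, None, :] + (g[:, None, :] - cost) / epsilon, dim=2)
        g_new = -epsilon * torch.logsumexp(
            log_a[:, :, None] + (f_new[:, :, None] - cost) / epsilon, dim=1)
        # Largest change in either potential during this iteration.
        update = max((f_new - f).abs().max().item(), (g_new - g).abs().max().item())
        f, g = f_new, g_new
        n_iter += 1
        if update < min_update_size:
            break  # potentials have settled
    return f, g, epsilon, n_iter
```

The function returns the epsilon the final potentials were computed with because the plan, exp(log a_i + log b_j + (f_i + g_j - C_ij)/ε), is only consistent with the potentials at that ε.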

Notes

Our implementation is closely based on the original code of Thornton and Corenflos, from which we take the following details. We exponentially decay the regularisation strength over the Sinkhorn iterations. We choose the initial regularisation strength to be the maximum value minus the minimum value in the 2D particle-state array of particle positions, after each dimension has been normalised to unit standard deviation. For numerical stability, we cap the magnitude of the contribution to the gradient from the transport matrix.
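The initial regularisation and the gradient cap can be sketched as follows. `initial_regularisation` implements one reading of the rule above (divide each dimension by its standard deviation across the particles, then take max minus min over the whole (N, D) array); the small floor on the standard deviation is an added assumption. `ClipGradient` is a hypothetical custom autograd function that passes the transport matrix through unchanged and clamps each element of its incoming gradient to [-clip, clip], which is one way to realise `transport_gradient_clip`.

```python
import torch


def initial_regularisation(x: torch.Tensor) -> torch.Tensor:
    """Starting epsilon per batch element, following one reading of the Notes.

    x: (B, N, D). Each dimension is divided by its standard deviation across
    the N particles; the result is max minus min over the whole (N, D) array.
    """
    std = x.std(dim=1, keepdim=True).clamp_min(1e-12)  # guard against zero spread
    flat = (x / std).flatten(start_dim=1)
    return flat.amax(dim=1) - flat.amin(dim=1)


class ClipGradient(torch.autograd.Function):
    """Identity in the forward pass; clamps each gradient element on the way back."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, clip: float) -> torch.Tensor:
        ctx.clip = clip
        return tensor.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward input; `clip` is a Python float, so None.
        return grad_output.clamp(-ctx.clip, ctx.clip), None
```

For example, `plan = ClipGradient.apply(plan, transport_gradient_clip)` would bound the gradient flowing back into `plan` without altering its forward value.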

Optimal transport resampling places particles in new positions on \(\mathbb{R}^n\), so it cannot directly be applied when some component of the state space is discrete/categorical.

Optimal transport resampling results in biased (but asymptotically consistent) estimates of all non-affine functions of the latent state, including the likelihood. The authors of the paper that proposed the method investigate this effect and find it small enough to ignore; see [1] for details.
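An extreme case makes the bias concrete. As the regularisation grows without bound, the entropy-regularised plan tends to the product coupling T_ij = w_j / N, and the barycentric map then sends every particle to the weighted mean. The sketch below (the function name is hypothetical) shows that this preserves the mean, an affine function, while collapsing the variance, a non-affine one, to zero. At the small regularisation strengths used in practice the shrinkage is expected to be much milder.

```python
import torch


def product_coupling_resample(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Barycentric map under the product coupling T_ij = w_j / N.

    x: (N,) scalar particles, w: (N,) normalised weights. This coupling has
    uniform row sums 1/N and column sums w_j.
    """
    N = x.shape[0]
    plan = w.expand(N, N) / N  # plan[i, j] = w[j] / N
    return N * plan @ x


torch.manual_seed(0)
x = torch.randn(1000, dtype=torch.float64)
w = torch.softmax(torch.randn(1000, dtype=torch.float64), dim=0)
new_x = product_coupling_resample(x, w)

weighted_mean = (w * x).sum()
weighted_var = (w * (x - weighted_mean) ** 2).sum()
new_var = ((new_x - new_x.mean()) ** 2).mean()
print(float(new_x.mean() - weighted_mean))  # approximately 0: the mean is kept
print(float(new_var), float(weighted_var))  # new variance is essentially 0
```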

References

[1]

Corenflos et al., Differentiable Particle Filtering via Entropy-Regularized Optimal Transport, 2021

__init__(regularisation: float, decay_rate: float, min_update_size: float, max_iterations: int, transport_gradient_clip: float)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

__init__(regularisation, decay_rate, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

diameter(x)

Calculates the diameter of the data.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extent(x)

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(state, weight, **data)

Run the optimal transport resampler.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_sinkhorn_inputs_OT(Nk, log_weight, x_t)

Get the inputs to the Sinkhorn algorithm as used for OT resampling.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules([remove_duplicate])

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

update()

Update all constrained_parameters and cached_properties belonging to this Module.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

Attributes

T_destination

call_super_init

dump_patches

training

static diameter(x: Tensor)

Calculates the diameter of the data. The diameter is defined as the maximum, over the data dimensions, of the standard deviation of each dimension across the sample.

Parameters:
x: Tensor

Input tensor.

Returns:
diameter: Tensor

The diameter of the data per batch.
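A one-line sketch of this definition, assuming `x` has shape (B, N, D) with particles along dimension 1, and using PyTorch's default (unbiased) standard deviation; which estimator pydpf uses is not stated here.

```python
import torch


def diameter(x: torch.Tensor) -> torch.Tensor:
    """x: (B, N, D) particles. Returns a (B,) tensor."""
    per_dim_std = x.std(dim=1)       # (B, D): spread of each dimension across the sample
    return per_dim_std.amax(dim=-1)  # (B,): the largest of those spreads
```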

static extent(x: Tensor)

forward(state: Tensor, weight: Tensor, **data)

Run the optimal transport resampler.

static get_sinkhorn_inputs_OT(Nk, log_weight: Tensor, x_t: Tensor) → Tuple[Tensor, Tensor, Tensor]

Get the inputs to the Sinkhorn algorithm as used for OT resampling.

Parameters:
Nk: int

The number of particles, N

log_weight: (B, N) Tensor

The log particle weights

x_t: (B, N, D) Tensor

The particle state

Returns:
log_uniform_weights: (B, N) Tensor

A tensor filled with \(\log(1/N)\)

cost_matrix: (B, N, N) Tensor

The pairwise distance matrix of the scaled particle states (scaled_x_t) under the 2-norm.

scale_x: (B, N, D) Tensor

The amount the particles were scaled by in calculating the cost matrix.
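A hypothetical sketch of what these three outputs could look like, assuming the particles are divided by their diameter (as defined for diameter above) before pairwise 2-norm distances are taken. The actual scaling in pydpf may differ, and `log_weight` is left out because none of the documented outputs depends on it in this sketch.

```python
import math
from typing import Tuple

import torch


def sinkhorn_inputs_sketch(N: int, x_t: torch.Tensor
                           ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical stand-in for get_sinkhorn_inputs_OT.

    x_t: (B, N, D). Returns log_uniform_weights (B, N), cost_matrix (B, N, N)
    and scale_x (B, N, D).
    """
    B, _, D = x_t.shape
    log_uniform_weights = torch.full((B, N), -math.log(N), dtype=x_t.dtype, device=x_t.device)
    # Assumed scaling: divide by the diameter (largest per-dimension std).
    diam = x_t.std(dim=1).amax(dim=-1).clamp_min(1e-12)  # (B,)
    scale_x = diam.reshape(B, 1, 1).expand(B, N, D)
    scaled_x_t = x_t / scale_x
    cost_matrix = torch.cdist(scaled_x_t, scaled_x_t)  # pairwise 2-norm distances
    return log_uniform_weights, cost_matrix, scale_x
```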