pydpf.resampling.OptimalTransportResampler
- class pydpf.resampling.OptimalTransportResampler(regularisation: float, decay_rate: float, min_update_size: float, max_iterations: int, transport_gradient_clip: float)
Bases: Module
Optimal transport based resampling.
Optimal transport resampling produces a differentiable, deterministic transport map from the proposal distribution to the posterior. This is achieved by solving an entropy-regularised Kantorovich optimal transport problem between the weighted empirical distribution of the particles and the uniformly weighted empirical distribution on the same particles. The particles are then transformed by the resulting transport map to obtain a new, unweighted approximation of the posterior.
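The idea can be illustrated with a minimal sketch, assuming a squared Euclidean cost, a fixed regularisation `eps` and a fixed number of log-domain Sinkhorn iterations. This is not pydpf's implementation (which anneals the regularisation, stops on convergence, rescales the particles and clips gradients); the function `ot_resample` is hypothetical. It solves for potentials between the weights `log_w` and the uniform weights \(1/N\), forms the transport plan \(P\), and moves particle \(j\) to \(N \sum_i P_{ij} x_i\).

```python
import math
import torch


def ot_resample(x: torch.Tensor, log_w: torch.Tensor, eps: float = 0.1, n_iter: int = 200) -> torch.Tensor:
    """Hypothetical entropy-regularised OT resampling for a batch of particle sets.

    x:     (B, N, D) particle positions
    log_w: (B, N) normalised particle log-weights
    Returns (B, N, D) equally weighted particles.
    """
    B, N, D = x.shape
    log_b = torch.full_like(log_w, -math.log(N))            # uniform target weights 1/N
    # Squared Euclidean cost between every pair of particles (an assumption).
    cost = ((x[:, :, None, :] - x[:, None, :, :]) ** 2).sum(-1)
    f = torch.zeros_like(log_w)                              # potential on the weighted side
    g = torch.zeros_like(log_w)                              # potential on the uniform side
    for _ in range(n_iter):
        # Log-domain Sinkhorn: enforce row sums = w, then column sums = 1/N.
        f = -eps * torch.logsumexp(log_b[:, None, :] + (g[:, None, :] - cost) / eps, dim=2)
        g = -eps * torch.logsumexp(log_w[:, :, None] + (f[:, :, None] - cost) / eps, dim=1)
    # Transport plan P_ij = w_i (1/N) exp((f_i + g_j - C_ij) / eps).
    log_p = log_w[:, :, None] + log_b[:, None, :] + (f[:, :, None] + g[:, None, :] - cost) / eps
    plan = log_p.exp()
    # New particle j is N times the P-weighted combination of the old particles.
    return N * torch.einsum("bij,bid->bjd", plan, x)
```

Because every column of the plan sums to \(1/N\), each new particle is a convex combination of the old ones, and every operation above is differentiable with respect to both `x` and `log_w`.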
- Parameters:
- regularisation: float
The minimum strength of the entropy regularisation. In our implementation the regularisation strength is chosen automatically per sample and decayed exponentially to this value.
- decay_rate: float
The factor by which the entropy regularisation is decreased at each Sinkhorn iteration.
- min_update_size: float
The size of the update to the transport potentials below which the iteration stops.
- max_iterations: int
The maximum number of iterations of the Sinkhorn loop, after which it stops regardless of convergence.
- transport_gradient_clip: float
The maximum per-element magnitude of the gradient passed back through the transport matrix. Gradients of larger magnitude are clipped to this value.
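How `regularisation`, `decay_rate`, `min_update_size` and `max_iterations` could interact in a Sinkhorn loop is sketched below. It is a hedged illustration, not pydpf's code: the function `annealed_sinkhorn` is hypothetical, and it assumes that `decay_rate` is a multiplicative factor in \((0, 1)\), that the regularisation is clamped at `regularisation`, and that early stopping is only allowed once that floor has been reached.

```python
import torch


def annealed_sinkhorn(log_a, log_b, cost, eps_init, regularisation, decay_rate,
                      min_update_size, max_iterations):
    """Hypothetical log-domain Sinkhorn with exponentially decayed regularisation.

    log_a, log_b: (B, N) log-marginals (particle log-weights and log(1/N)).
    cost: (B, N, N) cost matrix.  eps_init: (B,) starting regularisation.
    Returns the potentials f, g (B, N), the final regularisation (B,) and the
    number of iterations performed.
    """
    eps = eps_init.clone()
    f = torch.zeros_like(log_a)
    g = torch.zeros_like(log_b)
    n_iter = 0
    for n_iter in range(1, max_iterations + 1):
        e2, e3 = eps[:, None], eps[:, None, None]  # broadcast over potentials / cost
        f_new = -e2 * torch.logsumexp(log_b[:, None, :] + (g[:, None, :] - cost) / e3, dim=2)
        g_new = -e2 * torch.logsumexp(log_a[:, :, None] + (f_new[:, :, None] - cost) / e3, dim=1)
        # Largest absolute change in either potential (compared with min_update_size).
        update = torch.maximum((f_new - f).abs().amax(), (g_new - g).abs().amax())
        f, g = f_new, g_new
        # Stop early only once the regularisation has reached its floor.
        if bool((eps <= regularisation).all()) and update < min_update_size:
            break
        # Multiplicative decay, never going below `regularisation`.
        eps = torch.clamp(eps * decay_rate, min=regularisation)
    return f, g, eps, n_iter
```

Annealing from a large starting value lets the early, cheap iterations find a rough coupling before the weaker final regularisation sharpens it, which typically needs fewer iterations than starting at the final value directly.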
Notes
Our implementation closely follows the original code of Thornton and Corenflos, from which the following details are taken. The regularisation strength is decayed exponentially over the Sinkhorn iterations. Its initial value is the maximum minus the minimum entry of the particle-state 2D array of particle positions, after each dimension has been normalised to unit standard deviation. For numerical stability, the magnitude of the gradient contribution from the transport matrix is capped.
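Two of these details can be sketched in PyTorch. Both helpers below are hypothetical and only assume what the paragraph states: `initial_regularisation` normalises each dimension to unit (Bessel-corrected, an assumption) standard deviation and takes the maximum minus the minimum over the whole particle array of each batch element, and `ClipGradient` is an identity in the forward pass that clamps each gradient element to \([-c, c]\) in the backward pass, one plausible way to realise `transport_gradient_clip`.

```python
import torch


def initial_regularisation(x: torch.Tensor) -> torch.Tensor:
    """x: (B, N, D) particles. Returns a (B,) tensor of starting regularisation strengths."""
    scaled = x / x.std(dim=1, keepdim=True)           # unit standard deviation per dimension
    return scaled.amax(dim=(1, 2)) - scaled.amin(dim=(1, 2))


class ClipGradient(torch.autograd.Function):
    """Identity in the forward pass; clamps every gradient element in the backward pass."""

    @staticmethod
    def forward(ctx, tensor, clip):
        ctx.clip = clip
        return tensor.view_as(tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is returned for the `clip` argument.
        return grad_output.clamp(-ctx.clip, ctx.clip), None
```

In a resampler, the transport matrix would be wrapped as `ClipGradient.apply(plan, transport_gradient_clip)` before it is used to move the particles, so the forward result is unchanged and only large backward contributions are capped.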
Optimal transport resampling places particles in new positions on \(\mathbb{R}^n\), so it cannot directly be applied when some component of the state space is discrete/categorical.
Optimal transport resampling yields biased (but asymptotically consistent) estimates of all non-affine functions of the latent state, including the likelihood. The authors of the paper that proposed the method investigate this bias and find it small enough to ignore; see their paper [1] for details.
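Why affine functions survive while non-affine ones do not can be seen with any coupling \(P\) whose rows sum to the weights \(w\) and whose columns sum to \(1/N\): the new particles \(\tilde{x}_j = N \sum_i P_{ij} x_i\) have unweighted mean \(\sum_i w_i x_i\), but their spread shrinks. The sketch below (the helper `transport` is hypothetical) uses the independent coupling \(P = w\,(1/N)^\top\), the limit of the entropy-regularised plan as the regularisation grows without bound, where the bias in the second moment is most visible.

```python
import torch


def transport(plan: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Map particles x (N, D) through a coupling plan (N, N): x_new_j = N * sum_i plan_ij x_i."""
    return plan.shape[1] * plan.T @ x


torch.manual_seed(0)
N = 500
x = torch.randn(N, 1, dtype=torch.float64)                      # particle positions
w = torch.softmax(torch.randn(N, dtype=torch.float64), dim=0)   # normalised weights

# Independent coupling: row sums are w, column sums are 1/N.
plan = torch.outer(w, torch.full((N,), 1.0 / N, dtype=torch.float64))
x_new = transport(plan, x)

weighted_mean = (w[:, None] * x).sum(0)
weighted_second = (w[:, None] * x ** 2).sum(0)
print(x_new.mean(0) - weighted_mean)           # affine statistic: preserved up to rounding
print((x_new ** 2).mean(0) - weighted_second)  # non-affine statistic: negative, i.e. biased
```

With a weaker regularisation the plan is closer to a permutation-like map and the shrinkage is much smaller, which is consistent with the bias being small in practice.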
References
[1] Corenflos et al., Differentiable Particle Filtering via Entropy-Regularized Optimal Transport, 2021.
- __init__(regularisation: float, decay_rate: float, min_update_size: float, max_iterations: int, transport_gradient_clip: float)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
- __init__(regularisation, decay_rate, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
- add_module(name, module): Add a child module to the current module.
- apply(fn): Apply fn recursively to every submodule (as returned by .children()) as well as self.
- bfloat16(): Casts all floating point parameters and buffers to bfloat16 datatype.
- buffers([recurse]): Return an iterator over module buffers.
- children(): Return an iterator over immediate children modules.
- compile(*args, **kwargs): Compile this Module's forward using torch.compile().
- cpu(): Move all model parameters and buffers to the CPU.
- cuda([device]): Move all model parameters and buffers to the GPU.
- diameter(x): Calculates the diameter of the data.
- double(): Casts all floating point parameters and buffers to double datatype.
- eval(): Set the module in evaluation mode.
- extent(x)
- extra_repr(): Return the extra representation of the module.
- float(): Casts all floating point parameters and buffers to float datatype.
- forward(state, weight, **data): Run the optimal transport resampler.
- get_buffer(target): Return the buffer given by target if it exists, otherwise throw an error.
- get_extra_state(): Return any extra state to include in the module's state_dict.
- get_parameter(target): Return the parameter given by target if it exists, otherwise throw an error.
- get_sinkhorn_inputs_OT(Nk, log_weight, x_t): Get the inputs to the Sinkhorn algorithm as used for OT resampling.
- get_submodule(target): Return the submodule given by target if it exists, otherwise throw an error.
- half(): Casts all floating point parameters and buffers to half datatype.
- ipu([device]): Move all model parameters and buffers to the IPU.
- load_state_dict(state_dict[, strict, assign]): Copy parameters and buffers from state_dict into this module and its descendants.
- modules([remove_duplicate]): Return an iterator over all modules in the network.
- mtia([device]): Move all model parameters and buffers to the MTIA.
- named_buffers([prefix, recurse, ...]): Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- named_children(): Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
- named_modules([memo, prefix, remove_duplicate]): Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- named_parameters([prefix, recurse, ...]): Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- parameters([recurse]): Return an iterator over module parameters.
- register_backward_hook(hook): Register a backward hook on the module.
- register_buffer(name, tensor[, persistent]): Add a buffer to the module.
- register_forward_hook(hook, *[, prepend, ...]): Register a forward hook on the module.
- register_forward_pre_hook(hook, *[, ...]): Register a forward pre-hook on the module.
- register_full_backward_hook(hook[, prepend]): Register a backward hook on the module.
- register_full_backward_pre_hook(hook[, prepend]): Register a backward pre-hook on the module.
- register_load_state_dict_post_hook(hook): Register a post-hook to be run after module's load_state_dict() is called.
- register_load_state_dict_pre_hook(hook): Register a pre-hook to be run before module's load_state_dict() is called.
- register_module(name, module): Alias for add_module().
- register_parameter(name, param): Add a parameter to the module.
- register_state_dict_post_hook(hook): Register a post-hook for the state_dict() method.
- register_state_dict_pre_hook(hook): Register a pre-hook for the state_dict() method.
- requires_grad_([requires_grad]): Change if autograd should record operations on parameters in this module.
- set_extra_state(state): Set extra state contained in the loaded state_dict.
- set_submodule(target, module[, strict]): Set the submodule given by target if it exists, otherwise throw an error.
- share_memory(): See torch.Tensor.share_memory_().
- state_dict(*args[, destination, prefix, ...]): Return a dictionary containing references to the whole state of the module.
- to(*args, **kwargs): Move and/or cast the parameters and buffers.
- to_empty(*, device[, recurse]): Move the parameters and buffers to the specified device without copying storage.
- train([mode]): Set the module in training mode.
- type(dst_type): Casts all parameters and buffers to dst_type.
- update(): Update all constrained_parameters and cached_properties belonging to this Module.
- xpu([device]): Move all model parameters and buffers to the XPU.
- zero_grad([set_to_none]): Reset gradients of all model parameters.
Attributes
T_destination, call_super_init, dump_patches, training
- static diameter(x: Tensor)
Calculates the diameter of the data, defined as the maximum over data dimensions of the per-dimension standard deviation across the sample.
- Parameters:
- x: Tensor
Input tensor.
- Returns:
- diameter: Tensor
The diameter of the data per batch.
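A one-line sketch of this definition, assuming the input has shape (B, N, D) with the sample along dimension 1 and that the Bessel-corrected standard deviation is used (both assumptions; `diameter` here is a stand-alone function, not the library's static method):

```python
import torch


def diameter(x: torch.Tensor) -> torch.Tensor:
    """x: (B, N, D) batch of samples. Returns a (B,) tensor.

    Standard deviation of each data dimension across the sample (dim 1),
    then the maximum over the data dimensions (last dim).
    """
    return x.std(dim=1).amax(dim=-1)
```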
- static extent(x: Tensor)
- forward(state: Tensor, weight: Tensor, **data)
Run the optimal transport resampler.
- static get_sinkhorn_inputs_OT(Nk, log_weight: Tensor, x_t: Tensor) -> Tuple[Tensor, Tensor, Tensor]
Get the inputs to the Sinkhorn algorithm as used for OT resampling.
- Parameters:
- Nk: int
Number of particles.
- log_weight: (B,N) Tensor
The particle log-weights.
- x_t: (B,N,D) Tensor
The particle state.
- Returns:
- log_uniform_weights: (B,N) Tensor
A tensor filled with \(\log(1/N)\).
- cost_matrix: (B, N, N) Tensor
The matrix of pairwise 2-norm distances between the particles in scaled_x_t.
- scale_x: (B, N, D) Tensor
The amount by which the particles were scaled when calculating the cost matrix.
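The three outputs can be reproduced in spirit with the hypothetical function `sinkhorn_inputs` below. It follows the Notes in scaling each dimension to unit standard deviation and broadcasts that scale to (B, N, D); the library's actual `scale_x` may be computed differently (for example from the diameter), so treat the scaling rule as an assumption.

```python
import math
from typing import Tuple

import torch


def sinkhorn_inputs(n_particles: int, log_weight: torch.Tensor,
                    x_t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical sketch. log_weight: (B, N); x_t: (B, N, D)."""
    # Target marginal: every particle gets weight 1/N.
    log_uniform_weights = torch.full_like(log_weight, -math.log(n_particles))
    # Per-dimension standard deviation across particles, broadcast to (B, N, D).
    scale_x = x_t.std(dim=1, keepdim=True).expand_as(x_t)
    scaled_x_t = x_t / scale_x
    # (B, N, N) pairwise 2-norm distances between the scaled particles.
    cost_matrix = torch.cdist(scaled_x_t, scaled_x_t, p=2)
    return log_uniform_weights, cost_matrix, scale_x
```

The cost matrix is symmetric with a zero diagonal, since it compares the particle set with itself; the Sinkhorn problem is then posed between `log_weight` and `log_uniform_weights` on that single set of particles.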