It is not difficult to write a kernel \(k_{\theta}(x,y)\) in JAX, and there are many ways to do so. Here I'll offer my method: maintain a class module that holds its own trainable parameters, just as one might write a custom neural network layer. That way we don't have to carry the parameters around separately, and if we want to keep them fixed, we can.
You'll likely want a kernel utility that does three things: evaluates the kernel for a single pair of points in the same \(n\)-dimensional space; evaluates it for every pair drawn from two sets of points, producing a kernel (Gram) matrix; and evaluates the gradient of the kernel with respect to \(x\) or \(y\).
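Before wrapping anything in a class, the three operations can be sketched as plain functions. This is a minimal sketch with a fixed scale of 1 (an assumption), and `k`, `gram` and `grad_y` are hypothetical names. Each utility is just `jax.vmap` and `jax.grad` composed around the single-pair function.

```python
import jax
import jax.numpy as jnp


def k(x, y, scale=1.0):
    # Squared-exponential kernel for one pair of points of shape (ndims,); returns a scalar
    return jnp.exp(-(x - y) @ (x - y) / scale)


def gram(x, y):
    # K[i, j] = k(x[i], y[j]) for x: (n, ndims), y: (m, ndims) -> (n, m)
    return jax.vmap(jax.vmap(k, (None, 0)), (0, None))(x, y)


def grad_y(x, y):
    # G[i, j] = dk(x[i], y[j]) / dy[j] -> (n, m, ndims)
    dk_dy = jax.grad(k, argnums=1)
    return jax.vmap(jax.vmap(dk_dy, (None, 0)), (0, None))(x, y)


x = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
y = jnp.array([[0.0, 0.0], [1.0, 1.0]])
K = gram(x, y)    # shape (3, 2)
G = grad_y(x, y)  # shape (3, 2, 2)
```

For this kernel \(\nabla_y k(x,y) = \tfrac{2(x-y)}{\text{scale}}\,k(x,y)\), which is an easy closed form to check `grad_y` against.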
Here is a Squared Exponential / Gaussian kernel class in Equinox (my favorite JAX ML library), which does all of these things.
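```python
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx


class Gaussian(eqx.Module):
    scale: jax.Array  # raw (unconstrained) scale, softplus-transformed before use

    def __init__(self, *, key):
        key, _ = jr.split(key)
        # Uniform on [-3, 0); after softplus the scale starts roughly in (0.05, 0.69)
        self.scale = jr.uniform(key, minval=-3.0, maxval=0.0)

    def eval(self, x, y):
        # Kernel for one pair of points of shape (ndims,); returns a scalar
        scale = jax.nn.softplus(self.scale)
        return jnp.exp(-(x - y) @ (x - y) / scale)

    def __call__(self, x, y):
        # Gram matrix K[i, j] = k(x_i, y_j), shape (n_x, n_y)
        if x.ndim == 1 or y.ndim == 1:
            ndims = 1
        else:
            ndims = x.shape[-1]
        X, Y = x.reshape(-1, ndims), y.reshape(-1, ndims)
        # Outer vmap walks over the points of X (rows), inner vmap over Y (columns)
        k_xy = jax.vmap(jax.vmap(self.eval, (None, 0)), (0, None))(X, Y)
        return k_xy

    def grad(self, x, y):
        # dk(x_i, y_j)/dy_j, shape (n_x, n_y, ndims), or (n_x, n_y) when ndims == 1
        if x.ndim == 1 or y.ndim == 1:
            ndims = 1
        else:
            ndims = x.shape[-1]
        X, Y = x.reshape(-1, ndims), y.reshape(-1, ndims)

        def grad_wrt_y(x, y):
            return jax.grad(self.eval, argnums=1)(x, y).squeeze()

        # Same ordering as __call__, so the gradient really is taken w.r.t. y
        k_xy = jax.vmap(jax.vmap(grad_wrt_y, (None, 0)), (0, None))(X, Y)
        return k_xy
```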
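Picking this apart: `__init__` initializes the scale parameter by sampling a uniform distribution on \([-3, 0)\). Because `eval` passes the raw value through a softplus, the effective scale is positive whatever the sign of the raw parameter; starting in the negative range simply gives small initial scales of roughly 0.05 to 0.69 (drop the softplus if you'd rather handle positivity some other way). In Equinox, trainability is decided by filtering rather than by a flag: the usual `eqx.filter(model, eqx.is_inexact_array)` (or `eqx.filter_grad`) picks out floating-point JAX arrays stored as fields, such as `scale`, while integers and other non-float leaves are left alone.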
`eval` evaluates the kernel for a single pair of points `x` and `y`, each of shape `(ndims,)`, and returns a scalar (shape `()`). `__call__` uses a nifty double `vmap` to build the kernel matrix between two sets of points `x` and `y`: row \(i\) holds the kernel evaluations between the \(i\)-th point of `x` and every point of `y`. If `x` and `y` have shape `(n_points,)`, the spatial dimension is assumed to be 1; otherwise they should have shape `(n_points, n_dims)`.
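Because the Gaussian is symmetric, the Gram matrix alone does not reveal which argument indexes the rows. One way to check a double-`vmap` convention is to feed it a deliberately non-symmetric function. In this sketch `f` and `pairwise` are hypothetical names, and `pairwise` mirrors the reshape convention just described.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Deliberately non-symmetric pairwise function, so row/column roles are visible
    return jnp.sum(x) - 2.0 * jnp.sum(y)


def pairwise(fn, x, y):
    # (n_points,) inputs are read as 1-D points; otherwise inputs are (n_points, n_dims)
    ndims = 1 if (x.ndim == 1 or y.ndim == 1) else x.shape[-1]
    X, Y = x.reshape(-1, ndims), y.reshape(-1, ndims)
    # Outer vmap walks over X (rows), inner vmap over Y (columns): out[i, j] = fn(X[i], Y[j])
    return jax.vmap(jax.vmap(fn, (None, 0)), (0, None))(X, Y)


x = jnp.array([1.0, 2.0, 3.0])  # three points in 1-D
y = jnp.array([10.0, 20.0])     # two points in 1-D
M = pairwise(f, x, y)           # shape (3, 2); M[i, j] = x[i] - 2 * y[j]

# in_axes arranged the other way round: same (3, 2) shape, but entries are f(Y[j], X[i])
M_swapped = jax.vmap(jax.vmap(f, (0, None)), (None, 0))(y.reshape(-1, 1), x.reshape(-1, 1))
```

Both orderings produce an `(n_x, n_y)` array, so the shape alone won't catch a mix-up. For a symmetric kernel the values agree too, but for a gradient the ordering decides which argument you differentiate with respect to, and for a stationary kernel \(\nabla_x k = -\nabla_y k\).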
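The same pattern carries over to Flax NNX. Here is a compactly supported Wendland \(C^4\) kernel written as an `nnx.Module`. `inverse_softplus_perturbed_ones` is a small initializer helper; the definition below is an assumed stand-in that starts the softplus-transformed scale near 1.

```python
import jax
import jax.numpy as jnp
from flax import nnx


def inverse_softplus_perturbed_ones(key, shape, noise=0.01):
    # Assumed stand-in: softplus^{-1}(1) plus small Gaussian noise, so softplus(raw) ~ 1
    return jnp.log(jnp.expm1(1.0)) + noise * jax.random.normal(key, shape)


class WendlandC4Kernel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.scale = nnx.Param(inverse_softplus_perturbed_ones(key, (1,)))

    def eval(self, x, y):
        # Kernel for one pair of points of shape (ndims,). Indexing the (1,)-shaped
        # parameter keeps the output a scalar, which jax.grad requires.
        scale = jax.nn.softplus(self.scale.value[0])
        d2 = (x - y) @ (x - y)
        # Double-where "safe norm": sqrt has an infinite slope at 0, so guard it to
        # keep gradients finite when x == y
        safe_d2 = jnp.where(d2 > 0, d2, 1.0)
        r = jnp.where(d2 > 0, jnp.sqrt(safe_d2), 0.0) / scale
        # Compact support: exactly zero once r >= 1
        return jnp.where(r < 1, ((1 - r) ** 6) * (3 + 18 * r + 35 * r**2), 0.0)

    def __call__(self, x, y):
        # Gram matrix K[i, j] = k(x_i, y_j), shape (n_x, n_y)
        if x.ndim == 1 or y.ndim == 1:
            ndims = 1
        else:
            ndims = x.shape[-1]
        X, Y = x.reshape(-1, ndims), y.reshape(-1, ndims)
        return jax.vmap(jax.vmap(self.eval, (None, 0)), (0, None))(X, Y)

    def grad(self, x, y):
        # dk(x_i, y_j)/dy_j, shape (n_x, n_y, ndims), or (n_x, n_y) when ndims == 1
        if x.ndim == 1 or y.ndim == 1:
            ndims = 1
        else:
            ndims = x.shape[-1]
        X, Y = x.reshape(-1, ndims), y.reshape(-1, ndims)

        def grad_wrt_y(x, y):
            return jax.grad(self.eval, argnums=1)(x, y).squeeze()

        return jax.vmap(jax.vmap(grad_wrt_y, (None, 0)), (0, None))(X, Y)
```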
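The appeal of the Wendland kernel is compact support: \(k(x,y)\) is exactly zero once \(\lVert x-y\rVert \ge\) scale, so Gram matrices of spread-out points are sparse, and this particular \(C^4\) function is positive definite for inputs of up to three dimensions. The sketch below checks this with a plain function and a fixed scale of 1 (an assumption; `wendland_c4` and `gram` are hypothetical names). It also uses the double-`where` guard around `jnp.sqrt`, since \(\sqrt{\cdot}\) has an infinite derivative at 0 and a naive distance therefore yields NaN gradients at \(x = y\).

```python
import jax
import jax.numpy as jnp


def wendland_c4(x, y, scale=1.0):
    # Unnormalised Wendland C4 kernel for one pair of points of shape (ndims,); returns a scalar
    d2 = (x - y) @ (x - y)
    # Safe norm: never take sqrt(0), so reverse-mode gradients stay finite at x == y
    safe_d2 = jnp.where(d2 > 0, d2, 1.0)
    r = jnp.where(d2 > 0, jnp.sqrt(safe_d2), 0.0) / scale
    return jnp.where(r < 1, (1 - r) ** 6 * (3 + 18 * r + 35 * r**2), 0.0)


def gram(x, y):
    # K[i, j] = wendland_c4(x[i], y[j])
    return jax.vmap(jax.vmap(wendland_c4, (None, 0)), (0, None))(x, y)


# Two clusters more than one scale apart
pts = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0], [3.0, 0.5]])
K = gram(pts, pts)  # block-sparse: cross-cluster entries are exactly 0

# Gradient with respect to y at coincident points
grad_at_same_point = jax.grad(wendland_c4, argnums=1)(pts[0], pts[0])
```

The unnormalised function peaks at 3 when \(r = 0\); divide by 3 if you want \(k(x,x) = 1\).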