
2024

Trainable kernels in Jax

It is not difficult to write a kernel \(k_{\theta}(x,y)\) in Jax, and there are many ways to do so. Here I'll offer my approach: maintain a class module that holds its own trainable parameters, just as one might write a custom neural network layer. That way we don't have to carry the parameters around separately, and if we want to keep them fixed, we can.
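To make the idea concrete, here is a minimal sketch in plain JAX (no Flax or Equinox), assuming a squared-exponential (RBF) kernel \(k_{\theta}(x,y) = \sigma^2 \exp(-\lVert x-y\rVert^2 / 2\ell^2)\) as the example. The class `RBFKernel`, its `gram` method, and the parameter names `log_lengthscale`, `log_variance` and `trainable` are hypothetical choices for illustration. Registering the class as a pytree lets `jax.grad`, `jax.jit` and `jax.tree_util.tree_map` treat the kernel object itself as the parameter container, and a static `trainable` flag uses `jax.lax.stop_gradient` to hold the parameters fixed.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class RBFKernel:
    """Hypothetical RBF kernel k(x, y) = s^2 exp(-||x - y||^2 / (2 l^2)).

    Parameters are stored in log space so they stay positive during training.
    """

    def __init__(self, log_lengthscale=0.0, log_variance=0.0, trainable=True):
        self.log_lengthscale = jnp.asarray(log_lengthscale, dtype=jnp.float32)
        self.log_variance = jnp.asarray(log_variance, dtype=jnp.float32)
        self.trainable = trainable  # static config, not a pytree leaf

    # Pytree protocol: the parameters are leaves, the flag is static aux data.
    def tree_flatten(self):
        return (self.log_lengthscale, self.log_variance), self.trainable

    @classmethod
    def tree_unflatten(cls, trainable, children):
        obj = cls.__new__(cls)  # skip __init__ so leaves can be tracers
        obj.log_lengthscale, obj.log_variance = children
        obj.trainable = trainable
        return obj

    def _params(self):
        ls, var = self.log_lengthscale, self.log_variance
        if not self.trainable:
            # Block gradients so the kernel acts as a fixed function.
            ls = jax.lax.stop_gradient(ls)
            var = jax.lax.stop_gradient(var)
        return jnp.exp(ls), jnp.exp(var)

    def __call__(self, x, y):
        lengthscale, variance = self._params()
        sq_dist = jnp.sum((x - y) ** 2)
        return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)

    def gram(self, X, Y):
        # Pairwise matrix K[i, j] = k(X[i], Y[j]) via nested vmap.
        return jax.vmap(lambda x: jax.vmap(lambda y: self(x, y))(Y))(X)


# Usage: differentiate a toy objective with respect to the kernel object.
X = jnp.linspace(-1.0, 1.0, 5)[:, None]
kernel = RBFKernel()

def loss(k):
    # Toy objective, chosen only to have something to differentiate.
    return jnp.sum(k.gram(X, X))

grads = jax.grad(loss)(kernel)  # an RBFKernel holding d(loss)/d(params)
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, kernel, grads)
```

Because the gradients come back as another `RBFKernel`, a gradient step is a single `tree_map` over two kernels with the same structure. Passing `trainable=False` makes both gradients zero without changing any calling code. Libraries such as Equinox, whose `eqx.Module` is itself a registered pytree, would remove the hand-written `tree_flatten`/`tree_unflatten` boilerplate.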