Paper
Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX
Authors
Lukas König, Manuel Kuhn, David Kappel, Anand Subramoney
Abstract
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as leaky integrate-and-fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our framework uses Diffrax ODE solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API in which users need only specify the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, all of which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as leaky integrate-and-fire (LIF), quadratic integrate-and-fire (QIF), exponential integrate-and-fire (EIF), Izhikevich, and the event-based gated recurrent unit (EGRU), with both time-to-first-spike and state-based loss functions, demonstrating its utility for prototyping and testing event-based architectures trained with exact gradients. We also apply the framework to more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical pyramidal neurons for computation. Code is available at https://github.com/efficient-scalable-machine-learning/eventax.
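To make the idea of spike-time gradients that are exact with respect to the forward simulation more concrete, here is a minimal sketch in plain JAX. It is not the Eventax API and does not use Diffrax: a fixed-step RK4 integrator stands in for the ODE solver, and all names (`dynamics`, `spike_condition`, `integrate`, `threshold_gap`, `spike_time`) and constants are hypothetical choices for illustration. The sketch assumes a single LIF neuron driven by a constant current `w` greater than the threshold, finds the threshold crossing numerically with Newton's method, and then takes one differentiable Newton step so that the gradient of the spike time follows the implicit function theorem. Reset rules are omitted because only the first spike is considered.

```python
from functools import partial

import jax
import jax.numpy as jnp

V_TH = 1.0  # spike threshold (illustrative value)
TAU = 1.0   # membrane time constant (illustrative value)


def dynamics(v, w):
    """LIF membrane ODE dv/dt = (-v + w) / TAU with constant input current w."""
    return (-v + w) / TAU


def spike_condition(v):
    """Event function: crosses zero when the membrane potential reaches threshold."""
    return v - V_TH


@partial(jax.jit, static_argnames=("n_steps",))
def integrate(t, w, n_steps=200):
    """Integrate v from v(0) = 0 up to time t with fixed-step RK4.

    The step size h = t / n_steps depends on t, so the result is
    differentiable with respect to both the end time t and the parameter w.
    """
    h = t / n_steps

    def rk4_step(v, _):
        k1 = dynamics(v, w)
        k2 = dynamics(v + 0.5 * h * k1, w)
        k3 = dynamics(v + 0.5 * h * k2, w)
        k4 = dynamics(v + h * k3, w)
        return v + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    v0 = jnp.zeros_like(t * w)  # initial state with the promoted dtype of t and w
    v_end, _ = jax.lax.scan(rk4_step, v0, None, length=n_steps)
    return v_end


def threshold_gap(t, w):
    """Spike condition evaluated at time t; its root in t is the spike time."""
    return spike_condition(integrate(t, w))


def spike_time(w, t_guess=1.0, n_newton=8):
    """First spike time for input current w, differentiable with respect to w."""
    gap_dt = jax.grad(threshold_gap, argnums=0)

    # Step 1: locate the root numerically; no gradients flow through this search.
    w_fixed = jax.lax.stop_gradient(w)
    t = t_guess
    for _ in range(n_newton):
        t = t - threshold_gap(t, w_fixed) / gap_dt(t, w_fixed)
    t = jax.lax.stop_gradient(t)

    # Step 2: one differentiable Newton step. Its value is (numerically) t,
    # and its derivative is -(d gap / d w) / (d gap / d t), i.e. the
    # implicit-function-theorem gradient of the simulated spike time.
    slope = jax.lax.stop_gradient(gap_dt(t, w))
    return t - threshold_gap(t, w) / slope


w = 2.0
t_spike = spike_time(w)
dt_dw = jax.grad(spike_time)(w)
print("spike time:", t_spike, "d(spike time)/dw:", dt_dw)
```

For this particular neuron the spike time has the closed form t* = TAU · ln(w / (w − V_TH)), so with w = 2 the printed values should be close to ln 2 and −0.5, which gives a quick sanity check. The same two-step pattern does not rely on a closed form, which is why it extends to neuron models such as QIF or Izhikevich; the Newton search does, however, need a starting guess from which it converges.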
Metadata
arXiv ID: 2603.08146v1
Published: 2026-03-09
Updated: 2026-03-09
Primary category: cs.LG
Comment: 9 pages, 3 figures
Abstract page: https://arxiv.org/abs/2603.08146v1
PDF: https://arxiv.org/pdf/2603.08146v1