About the Book
A book about JAX and Flax, covering how to perform deep learning with both libraries.
JAX (What it is and how to use it in Python)
What is XLA?
Installing JAX
Setting up TPUs on Google Colab
Data types in JAX
Ways to create JAX arrays
Generating random numbers with JAX
Pure functions
JAX NumPy operations
JAX arrays are immutable
Out-of-Bounds Indexing
Data placement on devices in JAX
How fast is JAX?
Using jit() to speed up functions
How JIT works
Taking derivatives with grad()
Auto-vectorization with vmap
Parallelization with pmap
Debugging NaNs in JAX
Double (64-bit) precision
What is a pytree?
Handling state in JAX
Loading datasets with JAX
Building neural networks with JAX
Final thoughts
Optimizers in JAX and Flax
Adaptive vs stochastic gradient descent (SGD) optimizers
AdaBelief
AdaGrad
Adam – Adaptive moment estimation
AdamW
RAdam – Rectified Adam optimizer
AdaFactor
Fromage
Lamb – Layerwise adaptive large batch optimization
Lars – Layer-wise Adaptive Rate Scaling
SM3 – Square-root of Minima of Sums of Maxima of Squared-gradients Method
SGD – Stochastic Gradient Descent
Noisy SGD
Optimistic GD
Differentially Private SGD
RMSProp
Yogi
Final thoughts
JAX loss functions
What is a loss function?
Creating custom loss functions in JAX
Which loss functions are available in JAX?
Sigmoid binary cross entropy
Softmax cross entropy
Cosine distance
Cosine similarity
Huber loss
L2 loss
Log cosh
Smooth labels
Computing loss with JAX Metrics
How to monitor JAX loss functions
Why JAX loss NaNs happen
Final thoughts
Activation functions in JAX and Flax
ReLU – Rectified linear unit
PReLU – Parametric Rectified Linear Unit
Sigmoid
Log sigmoid
Softmax
Log softmax
ELU – Exponential linear unit activation
CELU – Continuously-differentiable exponential linear unit
GELU – Gaussian error linear unit activation
GLU – Gated linear unit activation
Soft sign
Softplus
Swish – Sigmoid Linear Unit (SiLU)
Custom activation functions in JAX and Flax
Final thoughts
How to load datasets in JAX with TensorFlow
How to load text data in JAX
Clean the text data
Label encode the sentiment column
Text preprocessing with TensorFlow
How to load image data in JAX
How to load CSV data in JAX
Final thoughts
Image classification with JAX & Flax
Loading the dataset
Define Convolutional Neural Network with Flax
Define loss
Compute metrics
Create training state
Define training step
Define evaluation step
Training function
Evaluate the model
Train and evaluate the model
Model performance
Final thoughts
Distributed training with JAX & Flax
Perform standard imports
Set up TPUs on Colab
Download the dataset
Load the dataset
Define the model with Flax
Create training state
Apply the model
Training function
Train the model
Model evaluation
Final thoughts
How to use TensorBoard in JAX & Flax
How to use TensorBoard
How to install TensorBoard
Using TensorBoard with Jupyter notebooks and Google Colab
How to launch TensorBoard
TensorBoard dashboards
How to use TensorBoard with Flax
How to log images with TensorBoard in Flax
How to log text with TensorBoard in Flax
Track model training in JAX using TensorBoard
How to profile JAX programs with TensorBoard
Programmatic profiling
Manual profiling with TensorBoard
How to profile a JAX program on a remote machine
Share TensorBoard dashboards
Final thoughts
Handling state in JAX & Flax (BatchNorm and DropOut layers)
Perform standard imports
Download the dataset
Loading datasets in JAX
Data processing with PyTorch
Define Flax model with BatchNorm and DropOut
Create loss function
Compute metrics
Create custom Flax training state
Training step
Evaluation step
Train Flax model
Set up TensorBoard in Flax
Train model
Save Flax model
Load Flax model
Evaluate Flax model
Visualize Flax model performance
Final thoughts
LSTM in JAX & Flax
Dataset download
Data processing with NLTK
Text vectorization with Keras
Create tf.data dataset
Define LSTM model in Flax
Compute metrics in Flax
Create training state
Define training step
Evaluate the Flax model
Create training function
Train LSTM model in Flax
Visualize LSTM model performance in Flax
Save LSTM model
Final thoughts
Flax vs. TensorFlow
Random number generation in TensorFlow and Flax
Model definition in Flax and TensorFlow
Activations in Flax and TensorFlow
Optimizers in Flax and TensorFlow
Metrics in Flax and TensorFlow
Computing gradients in Flax and TensorFlow
Loading datasets in Flax and TensorFlow
Training models in Flax vs. TensorFlow
Distributed training in Flax and TensorFlow
Working with TPU accelerators
Model evaluation
Visualize model performance
Final thoughts
Train ResNet in Flax from scratch (Distributed ResNet training)
Install Flax models
Perform standard imports
Download dataset
Loading dataset in Flax
Data transformation in Flax
Instantiate Flax ResNet model
Compute metrics
Create Flax model training state
Apply model function
TensorBoard in Flax
Train Flax ResNet model
Evaluate model with TensorBoard
Visualize Flax model performance
Save Flax ResNet model
Load Flax ResNet model
Final thoughts
Transfer learning with JAX & Flax
Install JAX ResNet
Download dataset
Data loading in JAX
Data processing
ResNet model definition
Create head network
Combine ResNet backbone with head
Load pre-trained ResNet 50
Get model and variables
Zero gradients
Define Flax optimizer
Define Flax loss function
Compute Flax metrics
Create Flax training state
Training step
Evaluation step
Train ResNet model in Flax
Set up TensorBoard in Flax
Train model
Save Flax model
Load saved Flax model
Evaluate Flax ResNet model
Visualize model performance
Final thoughts
Elegy (High-level API for deep learning in JAX & Flax)
Data pre-processing
Model definition in Elegy
Elegy model summary
Distributed training in Elegy
Keras-like callbacks in Flax
Train Elegy models
Evaluate Elegy models
Visualize Elegy model with TensorBoard
Plot model performance with Matplotlib
Making predictions with Elegy models
Saving and loading Elegy models
Final thoughts
Appendix
About the Author
Derrick Mwiti is a Data Scientist with expertise in machine learning, data analytics, and visualization. He enjoys working with data to derive meaningful insights that help business executives make decisions. He is an alumnus of the prestigious Meltwater Entrepreneurial School of Technology (MEST). Derrick is an avid contributor to the data science community, writing for popular data science publications such as KDnuggets, Heartbeat, and DataCamp, to mention a few. He holds a Bachelor of Science in Mathematics and Computer Science from Multimedia University.