GPU Acceleration

GPU training is supported, but you must install the backend package for your hardware first.

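For NVIDIA GPUs (CUDA):

```julia
using Pkg
Pkg.add("LuxCUDA")
# or, alternatively, install CUDA.jl and cuDNN.jl directly:
Pkg.add(["CUDA", "cuDNN"])
```

For AMD GPUs (AMDGPU):

```julia
using Pkg
Pkg.add("AMDGPU")
# If AMDGPU.jl does not work for your use case, consider Reactant.jl for GPU support.
```

For Apple GPUs (Metal):

```julia
using Pkg
Pkg.add("Metal")
```

For Intel GPUs (oneAPI):

```julia
using Pkg
Pkg.add("oneAPI")
```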
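
After installing a backend, it can help to check that Julia actually sees it. The sketch below assumes that EasyHybrid's `gpu_device()` comes from Lux's device layer, MLDataDevices.jl, as its use here suggests. It relies on that package's `MLDataDevices.loaded` and `MLDataDevices.functional` helpers and is not an EasyHybrid API. Run it in the same session after `using` your backend package (for example `using LuxCUDA`). Without one, every GPU entry reports `false`.

```julia
using MLDataDevices  # assumption: the device layer behind gpu_device() (re-exported by Lux)

# For each GPU backend, `loaded` tells whether its trigger package
# (LuxCUDA/CUDA + cuDNN, AMDGPU, Metal, oneAPI) is imported in this session;
# `functional` tells whether the device is actually usable on this machine.
backends = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
status = [T => (loaded = MLDataDevices.loaded(T), functional = MLDataDevices.functional(T))
          for T in backends]

for (T, s) in status
    println(rpad(string(T), 14), "loaded = ", s.loaded, ", functional = ", s.functional)
end
```

A backend that is `loaded` but not `functional` usually points to a driver or hardware problem rather than a missing package.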

Then load EasyHybrid together with the same backend package and call `gpu_device()` to obtain a device:

CUDA:

```julia
using EasyHybrid, LuxCUDA
gpu_device()
```

AMDGPU:

```julia
using EasyHybrid, AMDGPU
gpu_device()
```

Metal:

```julia
using EasyHybrid, Metal
gpu_device()
```

oneAPI:

```julia
using EasyHybrid, oneAPI
gpu_device()
```
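
The object returned by `gpu_device()` is itself callable and moves data to that device. The following minimal sketch again assumes the function comes from MLDataDevices.jl. The input array is hypothetical. If no GPU backend is functional, `gpu_device()` warns and returns a `CPUDevice()` instead of failing, so the snippet also runs on a machine without a GPU.

```julia
using MLDataDevices  # assumption: same gpu_device() that EasyHybrid exposes via Lux

dev = gpu_device()   # falls back to CPUDevice() (with a warning) if no GPU backend works
if dev isa CPUDevice
    @warn "No functional GPU backend found; computations will run on the CPU."
end

x = rand(Float32, 4, 8)        # hypothetical input batch: 4 features × 8 samples
x_dev = dev(x)                 # copy the array to the selected device
x_host = cpu_device()(x_dev)   # copy it back to host memory
println("device: ", dev, ", round trip ok: ", x_host == x)
```

Checking `dev isa CPUDevice` is a cheap way to catch a silent CPU fallback before starting a long training run.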

In your training call, pass `arch = GPU()`. For example:

```julia
using EasyHybrid, Metal

# `...` stands for your usual training arguments
train(...; arch = GPU())
```

That is all you need. Your hybrid model will now train on the GPU.