GPU Acceleration
GPU training is supported, but you must install the backend package for your hardware first.
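NVIDIA GPUs:

```julia
using Pkg
Pkg.add("LuxCUDA")
# or
Pkg.add(["CUDA", "cuDNN"])
```

AMD ROCm GPUs:

```julia
using Pkg
Pkg.add("AMDGPU")
# If your use case fails, consider Reactant.jl for GPU support.
```

Apple Metal GPUs:

```julia
using Pkg
Pkg.add("Metal")
```

Intel GPUs:

```julia
using Pkg
Pkg.add("oneAPI")
```

Then load EasyHybrid together with your backend package and call `gpu_device()` to access a device: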
NVIDIA GPUs:

```julia
using EasyHybrid, LuxCUDA
gpu_device()
```

AMD ROCm GPUs:

```julia
using EasyHybrid, AMDGPU
gpu_device()
```

Apple Metal GPUs:

```julia
using EasyHybrid, Metal
gpu_device()
```

Intel GPUs:

```julia
using EasyHybrid, oneAPI
gpu_device()
```

In your training call, pass `arch = GPU()`. For example:
```julia
using EasyHybrid, Metal  # Metal shown here; load the backend package for your hardware
train(...; arch = GPU())
```

That is all you need. Your hybrid model will now train on the GPU.
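
When one setup script has to run on machines with different GPUs, the choice of package from the install step can be written down as a lookup. The sketch below is illustrative only: `GPU_BACKENDS` and `backend_package` are hypothetical names that are not part of EasyHybrid, and the vendor symbols are an assumed labeling of the four packages listed in this section.

```julia
# Hypothetical helper, not part of EasyHybrid: maps a GPU vendor to the
# backend package named in the install step of this section.
const GPU_BACKENDS = Dict(
    :nvidia => "LuxCUDA",
    :amd    => "AMDGPU",
    :apple  => "Metal",
    :intel  => "oneAPI",
)

# Return the package name for `vendor`, or throw if the vendor is not listed.
function backend_package(vendor::Symbol)
    return get(GPU_BACKENDS, vendor) do
        throw(ArgumentError("unknown GPU vendor: $vendor"))
    end
end

# Example use in a setup script (performs a real install, so left commented out):
# using Pkg
# Pkg.add(backend_package(:apple))
```

The lookup only covers installation. The matching `using EasyHybrid, <backend>` line still has to be written explicitly in the training script, as in the examples above.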