Skip to content

[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device#13289

Open
JingyaHuang wants to merge 40 commits into
huggingface:mainfrom
JingyaHuang:add-neuron-backend
Open

[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device#13289
JingyaHuang wants to merge 40 commits into
huggingface:mainfrom
JingyaHuang:add-neuron-backend

Conversation

@JingyaHuang
Copy link
Copy Markdown
Contributor

@JingyaHuang JingyaHuang commented Mar 19, 2026

What does this PR do?

This PR adds AWS Neuron (Trainium/Inferentia) as an officially supported compute backend in Diffusers, on par with existing backends like CUDA, MPS, XPU, and MLU.

Changes

  • import_utils.py — adds is_torch_neuronx_available() detection, following the existing pattern for optional backends.
  • torch_utils.py — registers "neuron" in all backend dispatch tables (BACKEND_SUPPORTS_TRAINING, BACKEND_EMPTY_CACHE, BACKEND_DEVICE_COUNT, BACKEND_MANUAL_SEED, etc.) and adds a randn_tensor workaround since Neuron/XLA does not support creating random tensors directly on device (falls back to CPU).
  • utils/init.py — exports is_torch_neuronx_available.
  • pipeline_utils.py — adds two new DiffusionPipeline methods:
    • enable_neuron_compile(model_names, cache_dir, fullgraph) — wraps pipeline nn.Module components with torch.compile(backend="neuron") for whole-graph NEFF compilation. Supports optional NEFF caching via TORCH_NEURONX_NEFF_CACHE_DIR.
    • neuron_warmup(*args, **kwargs) — runs a single dummy forward pass to trigger upfront neuronx-cc compilation before timed inference.

Usage

  • Eager mode
import torch                                                                                                             
import torch_neuronx  # noqa: F401 — registers torch.neuron                                                            
                                                                                                                           
from diffusers import AutoPipelineForText2Image                                                                          
                                                                                                                           
# Load and move to Neuron device                                                                                         
pipe = AutoPipelineForText2Image.from_pretrained(                                                                        
    "stabilityai/sdxl-turbo",                                                                                            
    torch_dtype=torch.bfloat16,                           
    variant="fp16",                                                                                                      
)
pipe = pipe.to(torch.neuron.current_device())                                                                            
                                                                                                                         
# Warmup                                                                   
pipe(prompt="warmup", height=512, width=512, num_inference_steps=1, guidance_scale=0.0)                                                                                                                        
                                                          
# Inference                                                                                               
image = pipe(                                             
    prompt="a golden retriever surfing a wave, photorealistic",                                                          
    height=512,
    width=512,                                                                                                           
    num_inference_steps=1, 
    guidance_scale=0.0,                                                                    
).images[0]                                                                                                              
                                                                                                                         
image.save("output.png") 

Validation

So far we validated the following models, the idea is to manually validated a part of representative models, and ensure the model coverage with pur CIs afterward.

  • pixart
  • sdxl
  • flux2-klein-4B (able to run within one NeuronCore w/o. tp under eager mode for resolution 1024x1024)

Next Steps

  • Enable torch.compile on Neuron device
  • Add tensor parallel support for memory-bound devices like neuron
  • Tackle the compatibility of diffusers+nki kernels lib to boost the performance on neuron.
  • Support sequence parallel on Neuron for video models.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions github-actions Bot added lora examples size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 9, 2026
@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 10, 2026
@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 10, 2026
@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/M PR with diff < 200 LOC labels Apr 10, 2026
@JingyaHuang JingyaHuang marked this pull request as ready for review May 18, 2026 10:34
Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice thank you!

Comment thread tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py Outdated
@sayakpaul sayakpaul requested review from DN6 and yiyixuxu May 18, 2026 10:39
@sayakpaul
Copy link
Copy Markdown
Member

@claude-2-serge could you do a review please?

Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤗 Serge says:

This PR adds AWS Neuron (Trainium/Inferentia) as a supported backend. The core plumbing (import detection, backend dispatch tables, randn_tensor workaround, dtype fallbacks) follows existing patterns. However, there are several issues ranging from dead code in tests to an inconsistency between the pixart and SDXL neuron checks, and a missing BACKEND_SUPPORTS_TRAINING entry in the test dispatch tables.

Correctness

  • Dead atol in integration tests — Both Flux2KleinPipelineIntegrationTests and StableDiffusionXLTurboPipelineIntegrationTests compute atol but never use it in any assertion. The _ = atol line is a no-op. This looks like an incomplete implementation — either the tolerance should be used in an assertLessEqual against an expected slice, or the lines should be removed.

  • Inconsistent neuron detection in pixart vs SDXL — The pixart pipeline uses is_torch_neuronx_available() (a package-availability check) to decide timestep_device, while the SDXL pipeline correctly uses device.type == "neuron" (a runtime device check). If torch_neuronx is installed but the pipeline runs on CUDA, the pixart pipeline would unnecessarily force timesteps to CPU. The SDXL approach is more correct.

  • BACKEND_SUPPORTS_TRAINING not patched for neuron in tests/testing_utils.py — The _neuron_device key is added to all other dispatch tables but not to BACKEND_SUPPORTS_TRAINING. This means backend_supports_training(torch_device) falls back to "default": True, contradicting torch_utils.py where "neuron": False.

PR Description Mismatch

  • The description claims enable_neuron_compile and neuron_warmup methods are added to pipeline_utils.py, but neither exists in the diff or the codebase. Either these were dropped from the PR or the description is stale.

Tests — Environment Variable Leakage

  • setUp in Flux2KleinPipelineIntegrationTests sets TORCH_NEURONX_NEFF_CACHE_DIR and TORCH_NEURONX_ENABLE_NKI_SDPA but tearDown never cleans them up. Minor for slow tests but still a leak.

Suggestions / Dead Code Analysis

  • The image_slice variable is computed in both new integration tests but only used for a std() check — no expected-slice comparison is performed. Combined with the dead atol, these tests only verify shape and range, not numerical correctness. This is weak for integration tests.

35 LLM turns · 42 tool calls · 201.4s · 926751 in / 8552 out tokens

Comment thread tests/pipelines/flux2/test_pipeline_flux2_klein.py Outdated
Comment thread tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py Outdated
Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py Outdated
Comment thread tests/testing_utils.py
Comment thread tests/pipelines/flux2/test_pipeline_flux2_klein.py
@JingyaHuang JingyaHuang requested a review from sayakpaul May 18, 2026 15:43
@JingyaHuang
Copy link
Copy Markdown
Contributor Author

JingyaHuang commented May 18, 2026

Thanks @sayakpaul, I addressed the review comments, shall I merge?

Comment thread tests/pipelines/flux2/test_pipeline_flux2_klein.py
Comment thread tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
Copy link
Copy Markdown
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! i left some questions/comments

Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py Outdated
Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Comment thread src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Outdated

# 4. Prepare timesteps
if XLA_AVAILABLE:
is_neuron_device = hasattr(device, "type") and device.type == "neuron"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@claude-2-serge can you trace the code for how device is defined here and let me know if there is any code path that _exeuction_devide would not return torch.device type?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow-up inline comments are something we're working with @tarekziade. Should land pretty soon.

Copy link
Copy Markdown
Contributor Author

@JingyaHuang JingyaHuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed the comments! @sayakpaul @yiyixuxu

Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py Outdated
Comment thread src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Comment thread tests/pipelines/flux2/test_pipeline_flux2_klein.py
Comment thread src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Outdated
@JingyaHuang
Copy link
Copy Markdown
Contributor Author

The failing CI checks are unrelated to this PR, ok for us to merge? @sayakpaul @yiyixuxu

@JingyaHuang JingyaHuang requested review from sayakpaul and yiyixuxu May 26, 2026 15:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants