[Neuron] Add AWS Neuron (Trainium/Inferentia) as an officially supported device #13289
JingyaHuang wants to merge 40 commits into
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
… into add-neuron-backend
@claude-2-serge could you do a review please?
🤗 Serge says:
This PR adds AWS Neuron (Trainium/Inferentia) as a supported backend. The core plumbing (import detection, backend dispatch tables, randn_tensor workaround, dtype fallbacks) follows existing patterns. However, there are several issues ranging from dead code in tests to an inconsistency between the pixart and SDXL neuron checks, and a missing BACKEND_SUPPORTS_TRAINING entry in the test dispatch tables.
Correctness
- Dead `atol` in integration tests: Both `Flux2KleinPipelineIntegrationTests` and `StableDiffusionXLTurboPipelineIntegrationTests` compute `atol` but never use it in any assertion. The `_ = atol` line is a no-op. This looks like an incomplete implementation: either the tolerance should be used in an `assertLessEqual` against an expected slice, or the lines should be removed.
- Inconsistent neuron detection in pixart vs SDXL: The pixart pipeline uses `is_torch_neuronx_available()` (a package-availability check) to decide `timestep_device`, while the SDXL pipeline correctly uses `device.type == "neuron"` (a runtime device check). If `torch_neuronx` is installed but the pipeline runs on CUDA, the pixart pipeline would unnecessarily force timesteps to CPU. The SDXL approach is more correct.
- `BACKEND_SUPPORTS_TRAINING` not patched for neuron in `tests/testing_utils.py`: The `_neuron_device` key is added to all other dispatch tables but not to `BACKEND_SUPPORTS_TRAINING`. This means `backend_supports_training(torch_device)` falls back to `"default": True`, contradicting `torch_utils.py` where `"neuron": False`.
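The dispatch-table fallback described in the last item can be sketched as follows. This is a simplified illustration, not the code in `tests/testing_utils.py`: the table contents and the body of `backend_supports_training` are assumptions made for the example.

```python
# Assumed, simplified stand-ins for the dispatch table and helper in
# tests/testing_utils.py; the real definitions may differ.
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "default": True}


def backend_supports_training(device: str) -> bool:
    # Devices without their own entry silently use the "default" entry.
    return BACKEND_SUPPORTS_TRAINING.get(device, BACKEND_SUPPORTS_TRAINING["default"])


# Without a "neuron" entry, the lookup reports that training is supported.
before = backend_supports_training("neuron")

# Registering the key explicitly makes the test table agree with a
# `"neuron": False` entry elsewhere.
BACKEND_SUPPORTS_TRAINING["neuron"] = False
after = backend_supports_training("neuron")

print(before, after)
```

The point of the sketch is only that a missing key is not an error: it quietly inherits the default, so the mismatch would go unnoticed until a training test runs on the device.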
PR Description Mismatch
- The description claims `enable_neuron_compile` and `neuron_warmup` methods are added to `pipeline_utils.py`, but neither exists in the diff or the codebase. Either these were dropped from the PR or the description is stale.
Tests — Environment Variable Leakage
- `setUp` in `Flux2KleinPipelineIntegrationTests` sets `TORCH_NEURONX_NEFF_CACHE_DIR` and `TORCH_NEURONX_ENABLE_NKI_SDPA` but `tearDown` never cleans them up. Minor for slow tests but still a leak.
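One way to avoid this kind of leak, sketched under assumptions: the test class name, the cache path, and the variable values below are placeholders, not taken from the PR. `unittest.mock.patch.dict` snapshots `os.environ` and restores it when the patcher is stopped, and `addCleanup` guarantees that the stop runs even if the test fails.

```python
import os
import unittest
from unittest import mock


class NeuronEnvExampleTests(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        # Placeholder values; the real tests may use different ones.
        patcher = mock.patch.dict(
            os.environ,
            {
                "TORCH_NEURONX_NEFF_CACHE_DIR": "/tmp/neff-cache",
                "TORCH_NEURONX_ENABLE_NKI_SDPA": "1",
            },
        )
        patcher.start()
        # Restores the original environment after each test, pass or fail.
        self.addCleanup(patcher.stop)

    def test_env_is_visible_inside_test(self):
        self.assertEqual(os.environ["TORCH_NEURONX_ENABLE_NKI_SDPA"], "1")


suite = unittest.defaultTestLoader.loadTestsFromTestCase(NeuronEnvExampleTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

An explicit `tearDown` that pops both keys would work as well; the patcher approach just avoids having to remember the previous values.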
Suggestions / Dead Code Analysis
- The `image_slice` variable is computed in both new integration tests but only used for a `std()` check; no expected-slice comparison is performed. Combined with the dead `atol`, these tests only verify shape and range, not numerical correctness. This is weak for integration tests.
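A minimal sketch of what an expected-slice comparison could look like, assuming the slice has been flattened to a list of floats. The numbers are placeholders invented for the example, not real pipeline outputs, and `max_abs_diff` is a hypothetical helper.

```python
def max_abs_diff(actual, expected):
    """Largest element-wise absolute difference between two equal-length sequences."""
    return max(abs(a - e) for a, e in zip(actual, expected))


# Placeholder values for illustration only.
expected_slice = [0.10, 0.20, 0.30]
image_slice = [0.1004, 0.1997, 0.3002]
atol = 1e-3  # a backend-specific tolerance would be chosen here

diff = max_abs_diff(image_slice, expected_slice)

# Inside a unittest.TestCase this would be self.assertLessEqual(diff, atol).
assert diff <= atol, f"max abs diff {diff} exceeds atol {atol}"
```

With a check like this, the computed `atol` is actually consumed, and the test fails when the output drifts numerically rather than only when the shape or range is wrong.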
35 LLM turns · 42 tool calls · 201.4s · 926751 in / 8552 out tokens
Thanks @sayakpaul, I addressed the review comments. Shall I merge?
yiyixuxu left a comment
thanks! i left some questions/comments
    # 4. Prepare timesteps
    if XLA_AVAILABLE:
        is_neuron_device = hasattr(device, "type") and device.type == "neuron"
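To illustrate why a runtime check on `device.type` behaves differently from a package-availability check, here is a small sketch. `FakeDevice` is a hypothetical stand-in for `torch.device`, and `package_installed` stands in for `is_torch_neuronx_available()`; neither is the real API, and the function names are invented for the example.

```python
from dataclasses import dataclass


@dataclass
class FakeDevice:  # hypothetical stand-in for torch.device
    type: str


def package_installed() -> bool:
    # Stand-in for is_torch_neuronx_available(); pretend the package is installed.
    return True


def timestep_device_by_package(device: FakeDevice) -> str:
    # Package check: forces CPU whenever the package is importable,
    # regardless of where the pipeline actually runs.
    return "cpu" if package_installed() else device.type


def timestep_device_by_runtime(device: FakeDevice) -> str:
    # Runtime check: forces CPU only when the pipeline runs on a neuron device.
    is_neuron_device = hasattr(device, "type") and device.type == "neuron"
    return "cpu" if is_neuron_device else device.type


cuda = FakeDevice("cuda")
print(timestep_device_by_package(cuda))  # cpu (unnecessary fallback)
print(timestep_device_by_runtime(cuda))  # cuda
```

The `hasattr` guard only matters if `device` can be something other than a device object, which is the question raised in this thread.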
@claude-2-serge can you trace how `device` is defined here and let me know if there is any code path where `_execution_device` would not return a `torch.device`?
Follow-up inline comments are something we're working on with @tarekziade. Should land pretty soon.
Co-authored-by: YiYi Xu <yixu310@gmail.com>
JingyaHuang left a comment
addressed the comments! @sayakpaul @yiyixuxu
The failing CI checks are unrelated to this PR. OK for us to merge? @sayakpaul @yiyixuxu
What does this PR do?
This PR adds AWS Neuron (Trainium/Inferentia) as an officially supported compute backend in Diffusers, on par with existing backends like CUDA, MPS, XPU, and MLU.
Changes
Usage
Validation
So far we have validated the following models. The idea is to manually validate a representative subset of models and then ensure model coverage with our CIs afterward.
Next Steps
`torch.compile` on Neuron device