mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-05 02:28:00 +02:00
feat(device): support xpu backend (#24)
* feat(device): support xpu backend
* Fall back to CPU seed generator when device RNG unsupported (xpu)

Some torch-xpu builds have no device-side RNG, so torch.Generator(device="xpu") raises when --seed is used. _make_seed_generator tries the device generator and falls back to a backend-agnostic CPU generator. Adds a fallback unit test.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Victor Kuznetsov <kuznetsov.va@gmail.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
1598c499fe
commit
0c7ff1874e
@@ -28,6 +28,15 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
gpu = [
|
||||
"torch>=2.0.0",
|
||||
# The default PyPI torch wheel is a CPU/CUDA build. To drive an Intel GPU
|
||||
# (Arc / Data Center) via ``--device xpu`` you need an XPU-enabled torch
|
||||
# from PyTorch's XPU wheel index (Linux/Windows only -- there is no macOS
|
||||
# XPU build). Install that build first, then this extra (torch is then
|
||||
# already satisfied and won't be re-pulled):
|
||||
# pip install torch --index-url https://download.pytorch.org/whl/xpu
|
||||
# pip install 'remove-ai-watermarks[gpu]'
|
||||
# uv users can install from the same index URL (also declared as ``pytorch-xpu`` under [tool.uv]):
|
||||
# uv pip install torch --index-url https://download.pytorch.org/whl/xpu
|
||||
"diffusers>=0.38.0",
|
||||
# diffusers 0.38's auto-pipeline registry imports ``Qwen3VLForConditional
|
||||
# Generation`` (its ``nucleusmoe_image`` pipeline), which only exists in
|
||||
@@ -87,6 +96,15 @@ all = ["remove-ai-watermarks[gpu,detect,trustmark,lama,dev]"]
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
|
||||
# PyTorch Intel-GPU (XPU) wheel index. ``explicit = true`` keeps it inert for
|
||||
# the default CPU/CUDA install: uv consults it only when a torch install
|
||||
# explicitly targets it (see the ``gpu`` extra comment), so it does not alter
|
||||
# the locked CPU/CUDA resolution. Linux/Windows only -- no macOS XPU build.
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-xpu"
|
||||
url = "https://download.pytorch.org/whl/xpu"
|
||||
explicit = true
|
||||
|
||||
[project.scripts]
|
||||
remove-ai-watermarks = "remove_ai_watermarks.cli:main"
|
||||
|
||||
|
||||
@@ -448,7 +448,12 @@ def cmd_erase(
|
||||
default="default",
|
||||
help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).",
|
||||
)
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option(
|
||||
"--device",
|
||||
type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]),
|
||||
default="auto",
|
||||
help="Inference device.",
|
||||
)
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
@@ -675,7 +680,12 @@ def cmd_identify(ctx: click.Context, source: Path, no_visible: bool, as_json: bo
|
||||
help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).",
|
||||
)
|
||||
@click.option("--model", type=str, default=None, help="HuggingFace model ID for invisible removal.")
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option(
|
||||
"--device",
|
||||
type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]),
|
||||
default="auto",
|
||||
help="Inference device.",
|
||||
)
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
@@ -957,7 +967,12 @@ def _process_batch_image(
|
||||
default="default",
|
||||
help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).",
|
||||
)
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option(
|
||||
"--device",
|
||||
type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]),
|
||||
default="auto",
|
||||
help="Inference device.",
|
||||
)
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
|
||||
@@ -88,7 +88,7 @@ class InvisibleEngine:
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID. None = use default for pipeline.
|
||||
device: Device for inference (auto/cpu/mps/cuda). None = auto.
|
||||
device: Device for inference (auto/cpu/mps/cuda/xpu). None = auto.
|
||||
pipeline: Pipeline profile. "default" (SDXL base, defeats SynthID
|
||||
v2) or "ctrlregen" (CtrlRegen).
|
||||
hf_token: HuggingFace API token.
|
||||
|
||||
@@ -222,6 +222,19 @@ def get_device() -> str:
|
||||
return "cuda"
|
||||
except (AssertionError, RuntimeError):
|
||||
pass
|
||||
# Intel GPU (Arc / Data Center) via the torch XPU backend. The torch.xpu
|
||||
# namespace exists in stock wheels, but is_available() is only True on an
|
||||
# XPU-enabled build (download.pytorch.org/whl/xpu), so this is inert on the
|
||||
# default CPU/CUDA install. Checked before the nvidia-smi path so an Intel
|
||||
# box never triggers the CUDA reinstaller.
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore
|
||||
try:
|
||||
t = torch.tensor([1.0], device="xpu")
|
||||
_ = t + t
|
||||
del t
|
||||
return "xpu"
|
||||
except (AssertionError, RuntimeError):
|
||||
pass
|
||||
if _has_nvidia_gpu() and not os.environ.get(_CUDA_FIX_ENV_KEY):
|
||||
_reinstall_torch_cuda_and_restart()
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
@@ -229,6 +242,20 @@ def get_device() -> str:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _make_seed_generator(device: str, seed: int) -> Any:
|
||||
"""Build a seeded ``torch.Generator``, falling back to a CPU generator.
|
||||
|
||||
Some backends have no device-side RNG (notably certain torch-xpu builds),
|
||||
so ``torch.Generator(device="xpu")`` can raise. A CPU generator is
|
||||
backend-agnostic and still seeds the pipeline reproducibly, so fall back to
|
||||
it rather than failing the run when ``--seed`` is used on such a device.
|
||||
"""
|
||||
try:
|
||||
return torch.Generator(device=device).manual_seed(seed) # type: ignore
|
||||
except (RuntimeError, TypeError):
|
||||
return torch.Generator().manual_seed(seed) # type: ignore
|
||||
|
||||
|
||||
# Keep legacy name available for backwards compatibility
|
||||
_detect_model_profile_from_id = detect_model_profile
|
||||
|
||||
@@ -245,7 +272,7 @@ class WatermarkRemover:
|
||||
|
||||
Attributes:
|
||||
model_id: HuggingFace model ID for the diffusion model.
|
||||
device: Device to run inference on (cuda, mps, or cpu).
|
||||
device: Device to run inference on (cuda, xpu, mps, or cpu).
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_ID = DEFAULT_MODEL_ID
|
||||
@@ -270,8 +297,8 @@ class WatermarkRemover:
|
||||
self.device = (device or get_device()).lower()
|
||||
if self.device == "auto":
|
||||
self.device = get_device()
|
||||
if self.device not in {"cpu", "mps", "cuda"}:
|
||||
raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda.")
|
||||
if self.device not in {"cpu", "mps", "cuda", "xpu"}:
|
||||
raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda, xpu.")
|
||||
if torch_dtype is None:
|
||||
if self.device == "cpu" or self.device == "mps":
|
||||
self.torch_dtype = torch.float32 # type: ignore
|
||||
@@ -435,7 +462,7 @@ class WatermarkRemover:
|
||||
generator = None
|
||||
if seed is not None and _HAS_TORCH:
|
||||
self._set_progress(f"Setting reproducible seed: {seed}")
|
||||
generator = torch.Generator(device=self.device).manual_seed(seed) # type: ignore
|
||||
generator = _make_seed_generator(self.device, seed)
|
||||
|
||||
effective_steps = max(1, int(num_inference_steps * strength))
|
||||
self._set_progress(
|
||||
@@ -564,7 +591,7 @@ class WatermarkRemover:
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
self._set_progress("Loading Differential-Diffusion pipeline (protect-text)...")
|
||||
use_fp16 = self.device in {"mps", "cuda"}
|
||||
use_fp16 = self.device in {"mps", "cuda", "xpu"}
|
||||
load_kwargs: dict[str, Any] = {
|
||||
"custom_pipeline": _DIFF_PIPELINE_NAME,
|
||||
"custom_revision": _DIFF_PIPELINE_REVISION,
|
||||
|
||||
+45
-2
@@ -6,7 +6,7 @@ code paths work correctly on CPU, MPS (macOS), and CUDA (Linux/Windows).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestDeviceDetection:
|
||||
|
||||
def test_returns_valid_device(self):
|
||||
device = get_device()
|
||||
assert device in ("cpu", "mps", "cuda")
|
||||
assert device in ("cpu", "mps", "cuda", "xpu")
|
||||
|
||||
def test_cpu_fallback_when_no_gpu(self):
|
||||
"""On CI / machines without GPU, should fall back to cpu or mps."""
|
||||
@@ -39,6 +39,49 @@ class TestDeviceDetection:
|
||||
def test_no_torch_returns_cpu(self):
|
||||
assert get_device() == "cpu"
|
||||
|
||||
def test_xpu_selected_when_available(self):
|
||||
"""An XPU-enabled torch (no CUDA) routes to the Intel GPU backend.
|
||||
|
||||
The whole torch module is mocked so the smoke-test ops succeed without
|
||||
any real device; cuda must read False so the cuda branch is skipped.
|
||||
"""
|
||||
fake_torch = MagicMock()
|
||||
fake_torch.cuda.is_available.return_value = False
|
||||
fake_torch.xpu.is_available.return_value = True
|
||||
with patch("remove_ai_watermarks.noai.watermark_remover.torch", fake_torch):
|
||||
assert get_device() == "xpu"
|
||||
fake_torch.tensor.assert_called_with([1.0], device="xpu")
|
||||
|
||||
def test_init_accepts_xpu_and_selects_fp16(self):
|
||||
"""WatermarkRemover accepts device='xpu' and picks fp16 (not fp32)."""
|
||||
if not is_watermark_removal_available():
|
||||
pytest.skip("torch/diffusers not installed")
|
||||
import torch
|
||||
|
||||
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover
|
||||
|
||||
remover = WatermarkRemover(device="xpu")
|
||||
assert remover.device == "xpu"
|
||||
assert remover.torch_dtype == torch.float16
|
||||
|
||||
def test_seed_generator_falls_back_to_cpu_when_device_rng_unsupported(self):
|
||||
"""A device with no RNG backend (e.g. some torch-xpu builds) falls back
|
||||
to a CPU generator instead of raising when --seed is used."""
|
||||
from remove_ai_watermarks.noai import watermark_remover as wr
|
||||
|
||||
def fake_generator(device="cpu"):
|
||||
if device == "xpu":
|
||||
raise RuntimeError("Device type xpu is not supported for torch.Generator()")
|
||||
gen = MagicMock()
|
||||
gen.manual_seed.return_value = f"gen:{device}"
|
||||
return gen
|
||||
|
||||
fake_torch = MagicMock()
|
||||
fake_torch.Generator.side_effect = fake_generator
|
||||
with patch.object(wr, "torch", fake_torch):
|
||||
assert wr._make_seed_generator("xpu", 123) == "gen:cpu"
|
||||
assert wr._make_seed_generator("cuda", 123) == "gen:cuda"
|
||||
|
||||
|
||||
class TestMpsErrorDetection:
|
||||
"""Tests for MPS error detection helper."""
|
||||
|
||||
Reference in New Issue
Block a user