diff --git a/pyproject.toml b/pyproject.toml index e0324dc..b837743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,15 @@ dependencies = [ [project.optional-dependencies] gpu = [ "torch>=2.0.0", + # The default PyPI torch wheel is a CPU/CUDA build. To drive an Intel GPU + # (Arc / Data Center) via ``--device xpu`` you need an XPU-enabled torch + # from PyTorch's XPU wheel index (Linux/Windows only -- there is no macOS + # XPU build). Install that build first, then this extra (torch is then + # already satisfied and won't be re-pulled): + # pip install torch --index-url https://download.pytorch.org/whl/xpu + # pip install 'remove-ai-watermarks[gpu]' + # uv users can target the ``pytorch-xpu`` index declared under [tool.uv]: + # uv pip install torch --index-url https://download.pytorch.org/whl/xpu "diffusers>=0.38.0", # diffusers 0.38's auto-pipeline registry imports ``Qwen3VLForConditional # Generation`` (its ``nucleusmoe_image`` pipeline), which only exists in @@ -87,6 +96,15 @@ all = ["remove-ai-watermarks[gpu,detect,trustmark,lama,dev]"] [tool.uv] prerelease = "allow" +# PyTorch Intel-GPU (XPU) wheel index. ``explicit = true`` keeps it inert for +# the default CPU/CUDA install: uv consults it only when a torch install +# explicitly targets it (see the ``gpu`` extra comment), so it does not alter +# the locked CPU/CUDA resolution. Linux/Windows only -- no macOS XPU build. +[[tool.uv.index]] +name = "pytorch-xpu" +url = "https://download.pytorch.org/whl/xpu" +explicit = true + [project.scripts] remove-ai-watermarks = "remove_ai_watermarks.cli:main" diff --git a/src/remove_ai_watermarks/cli.py b/src/remove_ai_watermarks/cli.py index 4bcbf41..9adc927 100644 --- a/src/remove_ai_watermarks/cli.py +++ b/src/remove_ai_watermarks/cli.py @@ -448,7 +448,12 @@ def cmd_erase( default="default", help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).", ) -@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.") +@click.option( + "--device", + type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]), + default="auto", + help="Inference device.", +) @click.option("--seed", type=int, default=None, help="Random seed for reproducibility.") @click.option("--hf-token", type=str, default=None, help="HuggingFace API token.") @click.option( @@ -675,7 +680,12 @@ def cmd_identify(ctx: click.Context, source: Path, no_visible: bool, as_json: bo help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).", ) @click.option("--model", type=str, default=None, help="HuggingFace model ID for invisible removal.") -@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.") +@click.option( + "--device", + type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]), + default="auto", + help="Inference device.", +) @click.option("--seed", type=int, default=None, help="Random seed for reproducibility.") @click.option("--hf-token", type=str, default=None, help="HuggingFace API token.") @click.option( @@ -957,7 +967,12 @@ def _process_batch_image( default="default", help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).", ) -@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.") +@click.option( + "--device", + type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]), + default="auto", + help="Inference device.", +) @click.option("--seed", type=int, default=None, help="Random seed for reproducibility.") @click.option("--hf-token", type=str, default=None, help="HuggingFace API token.") @click.option( diff --git a/src/remove_ai_watermarks/invisible_engine.py b/src/remove_ai_watermarks/invisible_engine.py index 9fc0d0f..e9051d8 100644 --- a/src/remove_ai_watermarks/invisible_engine.py +++ b/src/remove_ai_watermarks/invisible_engine.py @@ -88,7 +88,7 @@ class InvisibleEngine: Args: model_id: HuggingFace model ID. None = use default for pipeline. - device: Device for inference (auto/cpu/mps/cuda). None = auto. + device: Device for inference (auto/cpu/mps/cuda/xpu). None = auto. pipeline: Pipeline profile. "default" (SDXL base, defeats SynthID v2) or "ctrlregen" (CtrlRegen). hf_token: HuggingFace API token. diff --git a/src/remove_ai_watermarks/noai/watermark_remover.py b/src/remove_ai_watermarks/noai/watermark_remover.py index 558504b..7ab43b7 100644 --- a/src/remove_ai_watermarks/noai/watermark_remover.py +++ b/src/remove_ai_watermarks/noai/watermark_remover.py @@ -222,6 +222,19 @@ def get_device() -> str: return "cuda" except (AssertionError, RuntimeError): pass + # Intel GPU (Arc / Data Center) via the torch XPU backend. The torch.xpu + # namespace exists in stock wheels, but is_available() is only True on an + # XPU-enabled build (download.pytorch.org/whl/xpu), so this is inert on the + # default CPU/CUDA install. Checked before the nvidia-smi path so an Intel + # box never triggers the CUDA reinstaller. + if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore + try: + t = torch.tensor([1.0], device="xpu") + _ = t + t + del t + return "xpu" + except (AssertionError, RuntimeError): + pass if _has_nvidia_gpu() and not os.environ.get(_CUDA_FIX_ENV_KEY): _reinstall_torch_cuda_and_restart() if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -229,6 +242,20 @@ def get_device() -> str: return "cpu" +def _make_seed_generator(device: str, seed: int) -> Any: + """Build a seeded ``torch.Generator``, falling back to a CPU generator. + + Some backends have no device-side RNG (notably certain torch-xpu builds), + so ``torch.Generator(device="xpu")`` can raise. A CPU generator is + backend-agnostic and still seeds the pipeline reproducibly, so fall back to + it rather than failing the run when ``--seed`` is used on such a device. + """ + try: + return torch.Generator(device=device).manual_seed(seed) # type: ignore + except (RuntimeError, TypeError): + return torch.Generator().manual_seed(seed) # type: ignore + + # Keep legacy name available for backwards compatibility _detect_model_profile_from_id = detect_model_profile @@ -245,7 +272,7 @@ class WatermarkRemover: Attributes: model_id: HuggingFace model ID for the diffusion model. - device: Device to run inference on (cuda, mps, or cpu). + device: Device to run inference on (cuda, xpu, mps, or cpu). """ DEFAULT_MODEL_ID = DEFAULT_MODEL_ID @@ -270,8 +297,8 @@ class WatermarkRemover: self.device = (device or get_device()).lower() if self.device == "auto": self.device = get_device() - if self.device not in {"cpu", "mps", "cuda"}: - raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda.") + if self.device not in {"cpu", "mps", "cuda", "xpu"}: + raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda, xpu.") if torch_dtype is None: if self.device == "cpu" or self.device == "mps": self.torch_dtype = torch.float32 # type: ignore @@ -435,7 +462,7 @@ class WatermarkRemover: generator = None if seed is not None and _HAS_TORCH: self._set_progress(f"Setting reproducible seed: {seed}") - generator = torch.Generator(device=self.device).manual_seed(seed) # type: ignore + generator = _make_seed_generator(self.device, seed) effective_steps = max(1, int(num_inference_steps * strength)) self._set_progress( @@ -564,7 +591,7 @@ class WatermarkRemover: from diffusers import DiffusionPipeline self._set_progress("Loading Differential-Diffusion pipeline (protect-text)...") - use_fp16 = self.device in {"mps", "cuda"} + use_fp16 = self.device in {"mps", "cuda", "xpu"} load_kwargs: dict[str, Any] = { "custom_pipeline": _DIFF_PIPELINE_NAME, "custom_revision": _DIFF_PIPELINE_REVISION, diff --git a/tests/test_platform.py b/tests/test_platform.py index 00dfce6..e37a5bd 100644 --- a/tests/test_platform.py +++ b/tests/test_platform.py @@ -6,7 +6,7 @@ code paths work correctly on CPU, MPS (macOS), and CUDA (Linux/Windows). from __future__ import annotations -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -27,7 +27,7 @@ class TestDeviceDetection: def test_returns_valid_device(self): device = get_device() - assert device in ("cpu", "mps", "cuda") + assert device in ("cpu", "mps", "cuda", "xpu") def test_cpu_fallback_when_no_gpu(self): """On CI / machines without GPU, should fall back to cpu or mps.""" @@ -39,6 +39,49 @@ class TestDeviceDetection: def test_no_torch_returns_cpu(self): assert get_device() == "cpu" + def test_xpu_selected_when_available(self): + """An XPU-enabled torch (no CUDA) routes to the Intel GPU backend. + + The whole torch module is mocked so the smoke-test ops succeed without + any real device; cuda must read False so the cuda branch is skipped. + """ + fake_torch = MagicMock() + fake_torch.cuda.is_available.return_value = False + fake_torch.xpu.is_available.return_value = True + with patch("remove_ai_watermarks.noai.watermark_remover.torch", fake_torch): + assert get_device() == "xpu" + fake_torch.tensor.assert_called_with([1.0], device="xpu") + + def test_init_accepts_xpu_and_selects_fp16(self): + """WatermarkRemover accepts device='xpu' and picks fp16 (not fp32).""" + if not is_watermark_removal_available(): + pytest.skip("torch/diffusers not installed") + import torch + + from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover + + remover = WatermarkRemover(device="xpu") + assert remover.device == "xpu" + assert remover.torch_dtype == torch.float16 + + def test_seed_generator_falls_back_to_cpu_when_device_rng_unsupported(self): + """A device with no RNG backend (e.g. some torch-xpu builds) falls back + to a CPU generator instead of raising when --seed is used.""" + from remove_ai_watermarks.noai import watermark_remover as wr + + def fake_generator(device="cpu"): + if device == "xpu": + raise RuntimeError("Device type xpu is not supported for torch.Generator()") + gen = MagicMock() + gen.manual_seed.return_value = f"gen:{device}" + return gen + + fake_torch = MagicMock() + fake_torch.Generator.side_effect = fake_generator + with patch.object(wr, "torch", fake_torch): + assert wr._make_seed_generator("xpu", 123) == "gen:cpu" + assert wr._make_seed_generator("cuda", 123) == "gen:cuda" + class TestMpsErrorDetection: """Tests for MPS error detection helper."""