mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-12 21:47:47 +02:00
feat(device): support xpu backend (#24)
* feat(device): support xpu backend * Fall back to CPU seed generator when device RNG unsupported (xpu) Some torch-xpu builds have no device-side RNG, so torch.Generator(device="xpu") raises when --seed is used. _make_seed_generator tries the device generator and falls back to a backend-agnostic CPU generator. Adds a fallback unit test. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Victor Kuznetsov <kuznetsov.va@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
1598c499fe
commit
0c7ff1874e
@@ -448,7 +448,12 @@ def cmd_erase(
|
||||
default="default",
|
||||
help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).",
|
||||
)
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option(
|
||||
"--device",
|
||||
type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]),
|
||||
default="auto",
|
||||
help="Inference device.",
|
||||
)
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
@@ -675,7 +680,12 @@ def cmd_identify(ctx: click.Context, source: Path, no_visible: bool, as_json: bo
|
||||
help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).",
|
||||
)
|
||||
@click.option("--model", type=str, default=None, help="HuggingFace model ID for invisible removal.")
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option(
|
||||
"--device",
|
||||
type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]),
|
||||
default="auto",
|
||||
help="Inference device.",
|
||||
)
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
@@ -957,7 +967,12 @@ def _process_batch_image(
|
||||
default="default",
|
||||
help="Pipeline profile (default=SDXL, ctrlregen=CtrlRegen).",
|
||||
)
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option(
|
||||
"--device",
|
||||
type=click.Choice(["auto", "cpu", "mps", "cuda", "xpu"]),
|
||||
default="auto",
|
||||
help="Inference device.",
|
||||
)
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
|
||||
@@ -88,7 +88,7 @@ class InvisibleEngine:
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID. None = use default for pipeline.
|
||||
device: Device for inference (auto/cpu/mps/cuda). None = auto.
|
||||
device: Device for inference (auto/cpu/mps/cuda/xpu). None = auto.
|
||||
pipeline: Pipeline profile. "default" (SDXL base, defeats SynthID
|
||||
v2) or "ctrlregen" (CtrlRegen).
|
||||
hf_token: HuggingFace API token.
|
||||
|
||||
@@ -222,6 +222,19 @@ def get_device() -> str:
|
||||
return "cuda"
|
||||
except (AssertionError, RuntimeError):
|
||||
pass
|
||||
# Intel GPU (Arc / Data Center) via the torch XPU backend. The torch.xpu
|
||||
# namespace exists in stock wheels, but is_available() is only True on an
|
||||
# XPU-enabled build (download.pytorch.org/whl/xpu), so this is inert on the
|
||||
# default CPU/CUDA install. Checked before the nvidia-smi path so an Intel
|
||||
# box never triggers the CUDA reinstaller.
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore
|
||||
try:
|
||||
t = torch.tensor([1.0], device="xpu")
|
||||
_ = t + t
|
||||
del t
|
||||
return "xpu"
|
||||
except (AssertionError, RuntimeError):
|
||||
pass
|
||||
if _has_nvidia_gpu() and not os.environ.get(_CUDA_FIX_ENV_KEY):
|
||||
_reinstall_torch_cuda_and_restart()
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
@@ -229,6 +242,20 @@ def get_device() -> str:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _make_seed_generator(device: str, seed: int) -> Any:
|
||||
"""Build a seeded ``torch.Generator``, falling back to a CPU generator.
|
||||
|
||||
Some backends have no device-side RNG (notably certain torch-xpu builds),
|
||||
so ``torch.Generator(device="xpu")`` can raise. A CPU generator is
|
||||
backend-agnostic and still seeds the pipeline reproducibly, so fall back to
|
||||
it rather than failing the run when ``--seed`` is used on such a device.
|
||||
"""
|
||||
try:
|
||||
return torch.Generator(device=device).manual_seed(seed) # type: ignore
|
||||
except (RuntimeError, TypeError):
|
||||
return torch.Generator().manual_seed(seed) # type: ignore
|
||||
|
||||
|
||||
# Keep legacy name available for backwards compatibility
|
||||
_detect_model_profile_from_id = detect_model_profile
|
||||
|
||||
@@ -245,7 +272,7 @@ class WatermarkRemover:
|
||||
|
||||
Attributes:
|
||||
model_id: HuggingFace model ID for the diffusion model.
|
||||
device: Device to run inference on (cuda, mps, or cpu).
|
||||
device: Device to run inference on (cuda, xpu, mps, or cpu).
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_ID = DEFAULT_MODEL_ID
|
||||
@@ -270,8 +297,8 @@ class WatermarkRemover:
|
||||
self.device = (device or get_device()).lower()
|
||||
if self.device == "auto":
|
||||
self.device = get_device()
|
||||
if self.device not in {"cpu", "mps", "cuda"}:
|
||||
raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda.")
|
||||
if self.device not in {"cpu", "mps", "cuda", "xpu"}:
|
||||
raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda, xpu.")
|
||||
if torch_dtype is None:
|
||||
if self.device == "cpu" or self.device == "mps":
|
||||
self.torch_dtype = torch.float32 # type: ignore
|
||||
@@ -435,7 +462,7 @@ class WatermarkRemover:
|
||||
generator = None
|
||||
if seed is not None and _HAS_TORCH:
|
||||
self._set_progress(f"Setting reproducible seed: {seed}")
|
||||
generator = torch.Generator(device=self.device).manual_seed(seed) # type: ignore
|
||||
generator = _make_seed_generator(self.device, seed)
|
||||
|
||||
effective_steps = max(1, int(num_inference_steps * strength))
|
||||
self._set_progress(
|
||||
@@ -564,7 +591,7 @@ class WatermarkRemover:
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
self._set_progress("Loading Differential-Diffusion pipeline (protect-text)...")
|
||||
use_fp16 = self.device in {"mps", "cuda"}
|
||||
use_fp16 = self.device in {"mps", "cuda", "xpu"}
|
||||
load_kwargs: dict[str, Any] = {
|
||||
"custom_pipeline": _DIFF_PIPELINE_NAME,
|
||||
"custom_revision": _DIFF_PIPELINE_REVISION,
|
||||
|
||||
Reference in New Issue
Block a user