mirror of
https://github.com/dongdongunique/EvoSynth.git
synced 2026-06-01 20:51:40 +02:00
34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict
|
|
|
|
class BaseModel(ABC):
|
|
"""
|
|
所有模型的抽象基类。
|
|
定义了所有模型(无论是本地白盒还是远程API黑盒)都必须遵守的接口。
|
|
"""
|
|
def __init__(self, **kwargs):
|
|
"""允许在配置文件中传入任意模型特定的参数。"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def query(self, text_input: str, image_input: Any = None) -> str:
|
|
"""
|
|
向模型发送查询并获取响应的核心方法。
|
|
对于纯文本模型,image_input 将被忽略。
|
|
这是所有模型都必须实现的功能。
|
|
"""
|
|
pass
|
|
|
|
def get_gradients(self, inputs) -> Dict:
|
|
"""
|
|
(可选) 获取梯度,用于白盒攻击。
|
|
如果模型不支持(如黑盒API模型),则直接抛出异常。
|
|
"""
|
|
raise NotImplementedError(f"{self.__class__.__name__} does not support gradient access.")
|
|
|
|
def get_embeddings(self, inputs) -> Any:
|
|
"""
|
|
(可选) 获取嵌入向量,用于白盒或灰盒攻击。
|
|
如果模型不支持,则直接抛出异常。
|
|
"""
|
|
raise NotImplementedError(f"{self.__class__.__name__} does not support embedding access.") |