mirror of
https://github.com/dongdongunique/EvoSynth.git
synced 2026-02-13 01:32:46 +00:00
34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
from abc import ABC, abstractmethod
|
||
from typing import Any, Dict
|
||
|
||
class BaseModel(ABC):
|
||
"""
|
||
所有模型的抽象基类。
|
||
定义了所有模型(无论是本地白盒还是远程API黑盒)都必须遵守的接口。
|
||
"""
|
||
def __init__(self, **kwargs):
|
||
"""允许在配置文件中传入任意模型特定的参数。"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
def query(self, text_input: str, image_input: Any = None) -> str:
|
||
"""
|
||
向模型发送查询并获取响应的核心方法。
|
||
对于纯文本模型,image_input 将被忽略。
|
||
这是所有模型都必须实现的功能。
|
||
"""
|
||
pass
|
||
|
||
def get_gradients(self, inputs) -> Dict:
|
||
"""
|
||
(可选) 获取梯度,用于白盒攻击。
|
||
如果模型不支持(如黑盒API模型),则直接抛出异常。
|
||
"""
|
||
raise NotImplementedError(f"{self.__class__.__name__} does not support gradient access.")
|
||
|
||
def get_embeddings(self, inputs) -> Any:
|
||
"""
|
||
(可选) 获取嵌入向量,用于白盒或灰盒攻击。
|
||
如果模型不支持,则直接抛出异常。
|
||
"""
|
||
raise NotImplementedError(f"{self.__class__.__name__} does not support embedding access.") |