refactor: standardize CSV loading from ./datasets and improve robustness

- Load all CSVs from ./datasets directory
- Add encoding_errors='ignore' for resilient CSV parsing
- Ensure prompt generators are converted to lists before sampling
This commit is contained in:
Hanyin
2025-05-19 16:19:38 +08:00
parent 444f908009
commit 335787d40e
+13 -7
View File
@@ -248,13 +248,13 @@ def load_jailbreak_v28k() -> ProbeDataset:
@cache_to_disk()
def load_local_csv() -> ProbeDataset:
"""Load prompts from local CSV files."""
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
prompts = []
for file in csv_files:
try:
df = pd.read_csv(file)
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
if "prompt" in df.columns:
prompts.extend(df["prompt"].tolist())
else:
@@ -270,7 +270,7 @@ def load_csv(file: str) -> ProbeDataset:
"""Load prompts from local CSV files."""
prompts = []
try:
df = pd.read_csv(file)
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
prompts = df["prompt"].tolist()
if "prompt" in df.columns:
prompts.extend(df["prompt"].tolist())
@@ -284,14 +284,14 @@ def load_csv(file: str) -> ProbeDataset:
@cache_to_disk(1)
def load_local_csv_files() -> list[ProbeDataset]:
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
datasets = []
for file in csv_files:
try:
df = pd.read_csv(file)
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
if "prompt" in df.columns:
prompts = df["prompt"].tolist()
datasets.append(create_probe_dataset(file, prompts, {"src": file}))
@@ -336,12 +336,18 @@ class StenographyTransformer:
)
continue
size = len(dataset.prompts)
# Convert async generator (if any) to a list
prompts = list(dataset.prompts) if hasattr(dataset.prompts, '__iter__') else []
size = len(prompts)
if size == 0:
continue
sample_size = max(1, int(size * 0.25)) # At least 1 sample
for name, fn in self.TRANSFORMATIONS.items():
logger.info(f"Applying {name} to {dataset.dataset_name}")
sampled = random.sample(dataset.prompts, min(sample_size, size))
sampled = random.sample(prompts, min(sample_size, size))
try:
transformed = [
f"Follow the instructions in {name}: {fn(str(p))}"