feat(add more tests):

This commit is contained in:
Alexander Myasoedov
2024-12-02 23:49:30 +02:00
parent 10dc91060f
commit d365113440
3 changed files with 63 additions and 6 deletions
@@ -0,0 +1,53 @@
from unittest.mock import patch
import pandas as pd
import pytest
from .model import RefusalClassifier
@pytest.fixture
def mock_training_data():
"""Create mock training data CSV content"""
data = {
"GPT4_response": ["I cannot help with that", "I must decline"],
"ChatGPT_response": ["I won't assist with that", "That's not appropriate"],
"Claude_response": ["I cannot comply", "That would be unethical"],
}
return pd.DataFrame(data)
@pytest.fixture
def classifier():
"""Create a RefusalClassifier instance with test paths"""
return RefusalClassifier(
model_path="test_model.joblib",
vectorizer_path="test_vectorizer.joblib",
scaler_path="test_scaler.joblib",
)
@pytest.fixture
def trained_classifier(classifier, mock_training_data):
"""Create a trained classifier with mock data"""
with patch("pandas.read_csv", return_value=mock_training_data):
classifier.train(["mock_data.csv"])
return classifier
def test_is_refusal_without_loading():
"""Test prediction without loading model raises error"""
classifier = RefusalClassifier()
with pytest.raises(ValueError, match="Model, vectorizer, or scaler not loaded"):
classifier.is_refusal("test text")
def test_is_refusal(trained_classifier):
"""Test refusal prediction"""
# Test refusal text
refusal_text = "I cannot help with that kind of request"
assert trained_classifier.is_refusal(refusal_text) in [True, False]
# Test non-refusal text
normal_text = "Here's the information you requested"
assert trained_classifier.is_refusal(normal_text) in [True, False]
+4 -2
View File
@@ -1,7 +1,9 @@
from pathlib import Path
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, Mock
from pathlib import Path
from .report import router
client = TestClient(router)
+6 -4
View File
@@ -1,9 +1,11 @@
import pytest
from fastapi.testclient import TestClient
from pathlib import Path
from ..models.schemas import Settings
from .static import router, get_static_file
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from ..models.schemas import Settings
from .static import get_static_file, router
client = TestClient(router)