quality_test.py

Code Hygiene Score: 65

Issues: 4

Line  Type          Description
249   magic_number  Magic number found: 60
251   magic_number  Magic number found: 60
278   magic_number  Magic number found: 100
287   magic_number  Magic number found: 60
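
All four findings are repeated literals in the report formatting near the end of the file: 60 is both the width of the "=" divider lines and the preview length for generated questions, and 100 is the preview length for the classification reasoning. A minimal fix would follow the pattern of the existing constants module; the names below (SEPARATOR_WIDTH, REASONING_PREVIEW_LEN, QUESTION_PREVIEW_LEN) are illustrative, not existing constants:

# Illustrative additions to constants.py; these names do not exist there yet.
SEPARATOR_WIDTH = 60         # width of the "=" divider lines
REASONING_PREVIEW_LEN = 100  # chars of reasoning shown in the report
QUESTION_PREVIEW_LEN = 60    # chars of a question shown per line

# The flagged call sites would then read, for example:
print(f"\n{'=' * SEPARATOR_WIDTH}")
print(f"     Begründung: {taxonomy_result['reasoning'][:REASONING_PREVIEW_LEN]}...")
print(f"     Q: {q.get('question', '?')[:QUESTION_PREVIEW_LEN]}...")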

Dependencies: 13

Functions: 8

Code

#!/usr/bin/env python3
"""
Quality comparison test for different LLM models in the pipeline.
Tests entity extraction, relation extraction, and taxonomy classification.
"""

import json
import os
import sys
import time

import requests

# Add pipeline directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from constants import (
    LLM_MAX_TOKENS,
    MS_PER_SECOND,
    PROMPT_TEXT_LIMIT,
    PROMPT_TEXT_LIMIT_SHORT,
    TEST_TIMEOUT,
)
from db import db

OLLAMA_HOST = "http://localhost:11434"

# Models to test
MODELS = {
    "gemma3": "gemma3:27b-it-qat",
    "anthropic": "claude-opus-4-5-20251101",
}


def get_anthropic_client():
    """Get Anthropic API client."""
    import anthropic

    api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    if not api_key:
        env_file = "/var/www/dev.campus.systemische-tools.de/.env"
        if os.path.exists(env_file):
            with open(env_file) as f:
                for line in f:
                    if line.startswith("ANTHROPIC_API_KEY="):
                        api_key = line.split("=", 1)[1].strip()
                        break
    return anthropic.Anthropic(api_key=api_key) if api_key else None


def run_ollama(model, prompt, timeout=TEST_TIMEOUT):
    """Run prompt through Ollama model."""
    start = time.time()
    try:
        response = requests.post(
            f"{OLLAMA_HOST}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": False,
                "format": "json",
                "options": {"temperature": 0.3, "num_predict": LLM_MAX_TOKENS},
            },
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
        elapsed = time.time() - start
        return {
            "response": data.get("response", ""),
            "tokens": data.get("eval_count", 0),
            "duration_ms": elapsed * MS_PER_SECOND,
            "success": True,
        }
    except Exception as e:
        return {"response": "", "error": str(e), "success": False, "duration_ms": (time.time() - start) * MS_PER_SECOND}


def run_anthropic(client, prompt, model="claude-opus-4-5-20251101"):
    """Run prompt through Anthropic model."""
    start = time.time()
    try:
        message = client.messages.create(
            model=model, max_tokens=LLM_MAX_TOKENS, messages=[{"role": "user", "content": prompt}]
        )
        elapsed = time.time() - start
        return {
            "response": message.content[0].text,
            "tokens": message.usage.input_tokens + message.usage.output_tokens,
            "input_tokens": message.usage.input_tokens,
            "output_tokens": message.usage.output_tokens,
            "duration_ms": elapsed * MS_PER_SECOND,
            "success": True,
        }
    except Exception as e:
        return {"response": "", "error": str(e), "success": False, "duration_ms": (time.time() - start) * MS_PER_SECOND}


def extract_entities(text, model_name, model_id, client=None):
    """Extract entities using specified model."""
    prompt = f"""Analysiere den folgenden deutschen Text und extrahiere alle wichtigen Entitäten.

Kategorisiere jede Entität als:
- PERSON (Namen von Personen)
- ORGANIZATION (Firmen, Institutionen, Gruppen)
- CONCEPT (Fachbegriffe, Methoden, Theorien)
- LOCATION (Orte, Länder)
- DATE (Zeitangaben)
- OTHER (Sonstiges)

Antworte NUR im JSON-Format:
{{"entities": [{{"name": "...", "type": "...", "context": "kurze Beschreibung"}}]}}

Text:
{text[:PROMPT_TEXT_LIMIT]}
"""

    if model_name == "anthropic":
        result = run_anthropic(client, prompt, model_id)
    else:
        result = run_ollama(model_id, prompt)

    # Best-effort JSON parse: take the outermost {...} block, since models may wrap it in prose
    entities = []
    if result["success"]:
        try:
            import re

            json_match = re.search(r"\{[\s\S]*\}", result["response"])
            if json_match:
                data = json.loads(json_match.group())
                entities = data.get("entities", [])
        except (json.JSONDecodeError, AttributeError):
            pass  # JSON parsing failed, keep empty entities

    result["entities"] = entities
    result["entity_count"] = len(entities)
    return result


def classify_taxonomy(text, model_name, model_id, client=None):
    """Classify text into taxonomy categories."""
    prompt = f"""Klassifiziere den folgenden Text in passende Kategorien.

Wähle aus diesen Hauptkategorien:
- Methoden (Therapiemethoden, Coaching-Techniken)
- Theorie (Konzepte, Modelle, Grundlagen)
- Praxis (Anwendung, Fallbeispiele, Übungen)
- Organisation (Strukturen, Prozesse, Rollen)
- Kommunikation (Gesprächsführung, Interaktion)
- Entwicklung (Persönliche Entwicklung, Veränderung)
- Teamarbeit (Teamdynamik, Zusammenarbeit)

Antworte NUR im JSON-Format:
{{"categories": ["...", "..."], "confidence": 0.0-1.0, "reasoning": "kurze Begründung"}}

Text:
{text[:PROMPT_TEXT_LIMIT_SHORT]}
"""

    if model_name == "anthropic":
        result = run_anthropic(client, prompt, model_id)
    else:
        result = run_ollama(model_id, prompt)

    # Parse JSON
    categories = []
    confidence = 0
    reasoning = ""
    if result["success"]:
        try:
            import re

            json_match = re.search(r"\{[\s\S]*\}", result["response"])
            if json_match:
                data = json.loads(json_match.group())
                categories = data.get("categories", [])
                confidence = data.get("confidence", 0)
                reasoning = data.get("reasoning", "")
        except (json.JSONDecodeError, AttributeError):
            pass  # JSON parsing failed, keep empty categories

    result["categories"] = categories
    result["confidence"] = confidence
    result["reasoning"] = reasoning
    return result


def generate_questions(text, model_name, model_id, client=None):
    """Generate quiz questions from text."""
    prompt = f"""Erstelle 3 Verständnisfragen zu folgendem Lerntext.
Die Fragen sollen das Verständnis der Kernkonzepte prüfen.

Antworte NUR im JSON-Format:
{{"questions": [
  {{"question": "...", "answer": "...", "difficulty": "leicht|mittel|schwer"}}
]}}

Text:
{text[:PROMPT_TEXT_LIMIT_SHORT]}
"""

    if model_name == "anthropic":
        result = run_anthropic(client, prompt, model_id)
    else:
        result = run_ollama(model_id, prompt)

    # Parse JSON
    questions = []
    if result["success"]:
        try:
            import re

            json_match = re.search(r"\{[\s\S]*\}", result["response"])
            if json_match:
                data = json.loads(json_match.group())
                questions = data.get("questions", [])
        except (json.JSONDecodeError, AttributeError):
            pass  # JSON parsing failed, keep empty questions

    result["questions"] = questions
    result["question_count"] = len(questions)
    return result


def run_quality_test(document_id):
    """Run full quality comparison test."""
    db.connect()

    # Get document content
    cursor = db.execute(
        """SELECT c.content FROM chunks c
           WHERE c.document_id = %s
           ORDER BY c.chunk_index""",
        (document_id,),
    )
    chunks = cursor.fetchall()
    cursor.close()

    full_text = "\n\n".join([c["content"] for c in chunks])
    print(f"Dokument geladen: {len(full_text)} Zeichen, {len(chunks)} Chunks\n")

    # Get Anthropic client
    anthropic_client = get_anthropic_client()

    results = {}

    for model_name, model_id in MODELS.items():
        print(f"\n{'=' * 60}")
        print(f"TESTE: {model_name} ({model_id})")
        print("=" * 60)

        results[model_name] = {"model_id": model_id, "tests": {}}

        # Resolve the client for this model; skip Anthropic when no API key was found
        client = anthropic_client if model_name == "anthropic" else None
        if model_name == "anthropic" and not client:
            print("  ÜBERSPRUNGEN: Kein Anthropic API Key")
            continue

        # Test 1: Entity Extraction
        print("\n[1/3] Entity Extraction...")
        entity_result = extract_entities(full_text, model_name, model_id, client)
        results[model_name]["tests"]["entities"] = entity_result
        print(f"  → {entity_result['entity_count']} Entitäten gefunden ({entity_result['duration_ms']:.0f}ms)")
        if entity_result.get("entities"):
            for e in entity_result["entities"][:5]:
                print(f"     • {e.get('name', '?')} ({e.get('type', '?')})")

        # Test 2: Taxonomy Classification
        print("\n[2/3] Taxonomy Classification...")
        taxonomy_result = classify_taxonomy(full_text, model_name, model_id, client)
        results[model_name]["tests"]["taxonomy"] = taxonomy_result
        print(
            f"  → Kategorien: {', '.join(taxonomy_result['categories'])} (Konfidenz: {taxonomy_result['confidence']})"
        )
        if taxonomy_result.get("reasoning"):
            print(f"     Begründung: {taxonomy_result['reasoning'][:100]}...")

        # Test 3: Question Generation
        print("\n[3/3] Question Generation...")
        question_result = generate_questions(full_text, model_name, model_id, client)
        results[model_name]["tests"]["questions"] = question_result
        print(f"  → {question_result['question_count']} Fragen generiert ({question_result['duration_ms']:.0f}ms)")
        if question_result.get("questions"):
            for q in question_result["questions"][:3]:
                print(f"     Q: {q.get('question', '?')[:60]}...")

    db.disconnect()
    return results


def print_report(results):
    """Print detailed comparison report."""
    print("\n")
    print("=" * 80)
    print("QUALITÄTSREPORT: Pipeline Output-Vergleich")
    print("=" * 80)

    # Entity comparison
    print("\n### 1. ENTITY EXTRACTION ###\n")
    print(f"{'Modell':<20} {'Entitäten':>10} {'Zeit (ms)':>12} {'Tokens':>10}")
    print("-" * 55)
    for model, data in results.items():
        if "entities" in data.get("tests", {}):
            e = data["tests"]["entities"]
            tokens = e.get("tokens", e.get("output_tokens", "-"))
            print(f"{model:<20} {e['entity_count']:>10} {e['duration_ms']:>12.0f} {tokens:>10}")

    # Taxonomy comparison
    print("\n### 2. TAXONOMY CLASSIFICATION ###\n")
    for model, data in results.items():
        if "taxonomy" in data.get("tests", {}):
            t = data["tests"]["taxonomy"]
            print(f"{model}: {', '.join(t['categories'])} (Konfidenz: {t['confidence']})")

    # Question comparison
    print("\n### 3. QUESTION GENERATION ###\n")
    for model, data in results.items():
        if "questions" in data.get("tests", {}):
            q = data["tests"]["questions"]
            print(f"\n{model} ({q['question_count']} Fragen, {q['duration_ms']:.0f}ms):")
            for i, question in enumerate(q.get("questions", [])[:3], 1):
                print(f"  {i}. {question.get('question', '?')}")
                print(f"     → {question.get('answer', '?')[:80]}...")

    # Summary
    print("\n### ZUSAMMENFASSUNG ###\n")

    summary = []
    for model, data in results.items():
        tests = data.get("tests", {})
        total_time = sum(t.get("duration_ms", 0) for t in tests.values())
        total_entities = tests.get("entities", {}).get("entity_count", 0)
        total_questions = tests.get("questions", {}).get("question_count", 0)
        categories = len(tests.get("taxonomy", {}).get("categories", []))
        summary.append(
            {
                "model": model,
                "time_ms": total_time,
                "entities": total_entities,
                "questions": total_questions,
                "categories": categories,
            }
        )

    print(f"{'Modell':<20} {'Gesamt-Zeit':>12} {'Entitäten':>10} {'Fragen':>8} {'Kategorien':>12}")
    print("-" * 65)
    for s in summary:
        print(f"{s['model']:<20} {s['time_ms']:>10.0f}ms {s['entities']:>10} {s['questions']:>8} {s['categories']:>12}")


if __name__ == "__main__":
    doc_id = int(sys.argv[1]) if len(sys.argv) > 1 else 2
    print(f"Starte Qualitätstest für Dokument {doc_id}...\n")

    results = run_quality_test(doc_id)
    print_report(results)
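
Notes

The best-effort JSON extraction (regex for the outermost object, then json.loads) is duplicated in extract_entities, classify_taxonomy, and generate_questions. A shared helper would remove the repetition; a minimal sketch, not part of the current file:

import json
import re


def parse_json_response(raw):
    """Extract the outermost JSON object from a model response, or {} on failure."""
    match = re.search(r"\{[\s\S]*\}", raw)
    if not match:
        return {}
    try:
        return json.loads(match.group())
    except json.JSONDecodeError:
        return {}

Each test function would then reduce its parsing block to data = parse_json_response(result["response"]) followed by the relevant data.get(...) call.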
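
As the __main__ block shows, the script takes an optional document id (default 2). It can also be driven programmatically; a sketch, assuming the module is importable and the document's chunks exist in the database:

import json

from quality_test import print_report, run_quality_test

results = run_quality_test(2)
print_report(results)

# Keep the raw measurements for diffing against later runs.
with open("quality_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2, default=str)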