entity_extractor.py

Code Hygiene Score: 45

Issues 2

| Zeile | Typ          | Beschreibung                    |
|-------|--------------|---------------------------------|
| 253   | magic_number | Magic Number gefunden: 1000     |
| -     | complexity   | Datei hat 503 Zeilen (max: 500) |

Dependencies 14

Funktionen 16

Code

"""
Entity Extraction - Extract and store entities from text.
"""

import json
import re
import sys
import time
import unicodedata

import requests

sys.path.insert(0, "/var/www/scripts/pipeline")

from config import ANTHROPIC_MODEL, OLLAMA_HOST
from constants import LLM_MAX_TOKENS, LLM_TIMEOUT, MS_PER_SECOND, PROMPT_TEXT_LIMIT
from db import db
from protokoll import protokoll


def _build_prompt_from_yaml(yaml_content: str, text: str) -> str:
    """Fill the ``{{TEXT}}`` placeholder of a prompt template.

    The template (YAML or plain text) is otherwise passed to the LLM unchanged.
    ``text`` is clipped to its first ``PROMPT_TEXT_LIMIT`` characters before
    substitution, and every occurrence of the placeholder is replaced.
    """
    excerpt = text[:PROMPT_TEXT_LIMIT]
    return yaml_content.replace("{{TEXT}}", excerpt)


# German special characters are transliterated before the generic ASCII folding,
# which would otherwise reduce "ä" to "a" instead of "ae". Only lowercase forms are
# needed because the input is lowercased first.
_GERMAN_TRANSLITERATION = str.maketrans({"ä": "ae", "ö": "oe", "ü": "ue", "ß": "ss"})


def normalize_name(name: str) -> str:
    """Return the canonical_name form of an entity name.

    Steps:
    - lowercase
    - transliterate German umlauts and sharp s (ä→ae, ö→oe, ü→ue, ß→ss)
    - strip remaining accents (NFKD decomposition, then drop non-ASCII)
    - turn runs of whitespace/hyphens into a single underscore
    - drop every character outside ``[a-z0-9_]``
    - collapse repeated underscores and trim them from both ends

    An empty (or otherwise falsy) name yields "".
    """
    if not name:
        return ""

    folded = name.lower().translate(_GERMAN_TRANSLITERATION)
    ascii_only = unicodedata.normalize("NFKD", folded).encode("ascii", "ignore").decode("ascii")
    underscored = re.sub(r"[\s\-]+", "_", ascii_only)
    cleaned = re.sub(r"[^a-z0-9_]", "", underscored)
    return re.sub(r"_+", "_", cleaned).strip("_")


# Category-key -> type-code mapping; rebuilt from the DB on every call (not cached).
def _get_category_type_map() -> dict[str, str]:
    """Map response category keys to entity type codes from the entity_types table.

    For every type code, two lowercase keys map back to the original code: the naive
    plural (code + "s") and the singular. For example, a code "PERSON" would yield
    "persons" and "person". If two keys collide, the later table row wins.
    """
    mapping: dict[str, str] = {}
    for row in db.get_entity_types():
        code = row["code"]
        key = code.lower()
        mapping[f"{key}s"] = code
        mapping[key] = code
    return mapping


def _get_valid_type_codes() -> set[str]:
    """Return the entity type codes currently defined in the DB (fetched on every call)."""
    codes = db.get_entity_type_codes()
    return codes


# Stopword set, loaded from the DB on first use; this module never refreshes it.
_stopword_cache: set[str] | None = None


def _get_stopwords() -> set[str]:
    """Return the stopword set, querying the DB only on the first call."""
    global _stopword_cache
    if _stopword_cache is not None:
        return _stopword_cache
    _stopword_cache = set(db.get_stopwords())
    return _stopword_cache


def _is_stopword(entity_name: str) -> bool:
    """Return True if the entity's canonical form is a stopword and should be filtered out.

    The name is normalized with ``normalize_name`` before the lookup.
    """
    return normalize_name(entity_name) in _get_stopwords()


def _validate_entity_in_text(entity_name: str, source_text: str) -> bool:
    """Strictly validate that entity appears EXACTLY in source text."""
    if not entity_name or len(entity_name) < 3:
        return False
    # Exact match required
    return entity_name in source_text


def _normalize_entity_response(result: dict, source_text: str) -> list[dict]:
    """Normalize a parsed entity response into a list of entity dicts.

    Two response shapes are accepted:
    1. Categorized: ``{"persons": ["..."], "roles": [...], ...}``.
       - Each category key is mapped to a type code via the entity_types table;
         unknown keys are upper-cased.
       - Non-list values and empty or non-string items are ignored.
       - Each item becomes ``{"name", "type", "description": None}``.
    2. Legacy: ``{"entities": [{"name": ..., ...}, ...]}``.
       - Entity dicts with a string ``name`` are returned unchanged, with whatever
         other keys they carry.

    In both shapes an entity survives only if its name occurs verbatim in
    ``source_text``, which filters out hallucinations. A response that is not a
    JSON object yields an empty list instead of raising.
    """
    # json.loads may return a list/str/number when the model ignores the format.
    if not isinstance(result, dict):
        return []

    # Legacy format
    if "entities" in result:
        legacy_entities = result.get("entities")
        if not isinstance(legacy_entities, list):
            return []
        # The name must be a string: the substring check raises TypeError otherwise,
        # and a single malformed item would make the caller discard the whole reply.
        return [
            e
            for e in legacy_entities
            if isinstance(e, dict)
            and isinstance(e.get("name"), str)
            and _validate_entity_in_text(e["name"], source_text)
        ]

    # Categorized format
    category_map = _get_category_type_map()
    entities = []
    for category, items in result.items():
        if not isinstance(items, list):
            continue

        entity_type = category_map.get(category.lower(), category.upper())

        for item in items:
            if not item or not isinstance(item, str):
                continue

            # Strict validation: entity must appear EXACTLY in source text
            if not _validate_entity_in_text(item, source_text):
                continue  # Skip hallucinations

            entities.append(
                {
                    "name": item,
                    "type": entity_type,
                    "description": None,
                }
            )

    return entities


def _build_pass2_categories() -> str:
    """Render the DB entity types as the ``{categories}`` section of the pass-2 prompt.

    Each type becomes one line: two spaces of indent, the code, a colon, then its
    criteria. Lines are joined with newlines.
    """
    return "\n".join(f"  {row['code']}: {row['criteria']}" for row in db.get_entity_types())


def _call_ollama(prompt: str, model: str, timeout: int = LLM_TIMEOUT) -> tuple[str, int, int, int]:
    """Send one non-streaming, JSON-format generate request to Ollama.

    Args:
        prompt: Full prompt text.
        model: Ollama model name.
        timeout: Request timeout passed to ``requests.post``.

    Returns:
        ``(response_text, tokens_in, tokens_out, duration_ms)``. If Ollama omits a
        field, ``response_text`` defaults to "{}" and the token counts to 0.

    Raises:
        requests exceptions, e.g. ``HTTPError`` from ``raise_for_status`` or a
        timeout/connection error.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed (or go
    # negative) if the wall clock is adjusted during the request.
    start = time.perf_counter()
    response = requests.post(
        f"{OLLAMA_HOST}/api/generate",
        json={"model": model, "prompt": prompt, "stream": False, "format": "json"},
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    duration_ms = int((time.perf_counter() - start) * MS_PER_SECOND)
    return (
        data.get("response", "{}"),
        data.get("prompt_eval_count", 0),
        data.get("eval_count", 0),
        duration_ms,
    )


def extract_entities_ollama(text: str, model: str = "gemma3:27b-it-qat") -> list[dict]:
    """Extract entities from ``text`` with a local Ollama model.

    The preferred path is the 2-pass flow:
    - pass 1 finds entity names;
    - pass 2 assigns types, validated against the entity_types table.

    It runs only when both the ``entity_extraction_pass1`` and
    ``entity_extraction_pass2`` prompts exist in the DB. Otherwise the single-pass
    extraction is used.
    """
    pass1_prompt = db.get_prompt("entity_extraction_pass1")
    pass2_prompt = db.get_prompt("entity_extraction_pass2")

    if not (pass1_prompt and pass2_prompt):
        # At least one 2-pass prompt is missing -> legacy single pass.
        return _extract_entities_single_pass(text, model)
    return _extract_entities_2pass(text, pass1_prompt, pass2_prompt, model)


def _extract_entities_2pass(text: str, pass1_template: str, pass2_template: str, model: str) -> list[dict]:
    """2-pass entity extraction: extract then categorize."""
    try:
        # PASS 1: Extract entity names
        prompt1 = pass1_template.replace("{text}", text[:PROMPT_TEXT_LIMIT])
        resp1, tok_in1, tok_out1, dur1 = _call_ollama(prompt1, model)

        try:
            result1 = json.loads(resp1)
            raw_entities = result1.get("entities", [])
        except json.JSONDecodeError:
            db.log("WARNING", "Failed to parse Pass 1 JSON")
            return []

        # Validate: only keep entities that appear in text and are not stopwords
        valid_entities = [e for e in raw_entities if _validate_entity_in_text(e, text) and not _is_stopword(e)]

        if not valid_entities:
            return []

        protokoll.log_llm_call(
            request=f"[entity_extraction_pass1] {len(valid_entities)} entities",
            response=json.dumps(valid_entities[:10], ensure_ascii=False),
            model_name=f"ollama:{model}",
            tokens_input=tok_in1,
            tokens_output=tok_out1,
            duration_ms=dur1,
            status="completed",
        )

        # PASS 2: Categorize entities (with dynamic categories from DB)
        entities_json = json.dumps(valid_entities, ensure_ascii=False)
        categories_text = _build_pass2_categories()
        prompt2 = pass2_template.replace("{categories}", categories_text)
        prompt2 = prompt2.replace("{entities}", entities_json)
        resp2, tok_in2, tok_out2, dur2 = _call_ollama(prompt2, model)

        try:
            result2 = json.loads(resp2)
            categorized = result2.get("kategorisiert", [])
        except json.JSONDecodeError:
            db.log("WARNING", "Failed to parse Pass 2 JSON")
            # Fallback: return uncategorized entities
            return [{"name": e, "type": "CONCEPT", "description": None} for e in valid_entities]

        protokoll.log_llm_call(
            request=f"[entity_extraction_pass2] categorize {len(valid_entities)} entities",
            response=resp2[:1000],
            model_name=f"ollama:{model}",
            tokens_input=tok_in2,
            tokens_output=tok_out2,
            duration_ms=dur2,
            status="completed",
        )

        # Normalize output (validate types against DB)
        valid_types = _get_valid_type_codes()
        entities = []
        for e in categorized:
            if isinstance(e, dict) and "name" in e and "type" in e:
                # Final validation: in text, not stopword
                name = e["name"]
                if _validate_entity_in_text(name, text) and not _is_stopword(name):
                    entity_type = e["type"].upper()
                    # Fallback to CONCEPT if type not in DB
                    if entity_type not in valid_types:
                        entity_type = "CONCEPT"
                    entities.append(
                        {
                            "name": name,
                            "type": entity_type,
                            "description": e.get("description"),
                        }
                    )

        return entities

    except Exception as e:
        db.log("ERROR", f"2-pass entity extraction failed: {e}")
        return []


def _extract_entities_single_pass(text: str, model: str) -> list[dict]:
    """Extract entities with a single Ollama call (legacy fallback for the 2-pass flow).

    Uses the DB prompt for use case ``entity_extraction``. If none is stored, a
    warning is logged and a built-in German template is used. The LLM call is
    logged, and the JSON reply is normalized by ``_normalize_entity_response``
    against ``text``.

    Returns:
        A list of entity dicts. Empty when the reply is not valid JSON or when any
        error occurs; errors are written to the DB log.
    """
    stored_prompt = db.get_prompt_by_use_case("entity_extraction")
    template = stored_prompt["content"] if stored_prompt else None

    if not template:
        db.log("WARNING", "entity_extraction prompt not found in DB, using fallback")
        template = """Analysiere den Text und extrahiere wichtige Entitäten.
Kategorisiere als: PERSON, ORGANIZATION, CONCEPT, LOCATION
Antworte NUR im JSON-Format:
{"entities": [{"name": "...", "type": "...", "description": "..."}]}

Text:
{{TEXT}}"""

    # The template may be YAML or plain text; only {{TEXT}} is substituted.
    full_prompt = _build_prompt_from_yaml(template, text)

    try:
        raw_reply, tokens_in, tokens_out, elapsed_ms = _call_ollama(full_prompt, model)

        protokoll.log_llm_call(
            request=f"[entity_extraction] {full_prompt[:500]}...",
            response=raw_reply[:2000],
            model_name=f"ollama:{model}",
            tokens_input=tokens_in,
            tokens_output=tokens_out,
            duration_ms=elapsed_ms,
            status="completed",
        )

        try:
            parsed_reply = json.loads(raw_reply)
            return _normalize_entity_response(parsed_reply, text)
        except json.JSONDecodeError:
            db.log("WARNING", "Failed to parse entity JSON from Ollama")
            return []
    except Exception as err:
        db.log("ERROR", f"Ollama entity extraction failed: {err}")
        return []


def extract_entities_anthropic(text: str, client) -> list[dict]:
    """Extract entities using Anthropic Claude.

    Prompt and call:
    - Uses the DB prompt for use case ``entity_extraction``, or a built-in German
      template when none is stored.
    - Sends the prompt through ``client.messages.create`` and logs the call.

    Parsing the reply:
    - The span from the first ``{`` to the last ``}`` is parsed as JSON.
    - The result is normalized exactly like the Ollama single-pass reply, so both
      the categorized format and the legacy ``{"entities": [...]}`` format work.
    - Names that do not appear verbatim in ``text`` are dropped.

    Args:
        text: Source text; only the first PROMPT_TEXT_LIMIT characters are sent.
        client: Anthropic client; presumably ``anthropic.Anthropic``, and it must
            provide ``messages.create``.

    Returns:
        A list of entity dicts. Empty on any API, parse or normalization failure.
    """
    prompt_data = db.get_prompt_by_use_case("entity_extraction")
    prompt_content = prompt_data["content"] if prompt_data else None

    if not prompt_content:
        prompt_content = """Analysiere den folgenden deutschen Text und extrahiere alle wichtigen Entitäten.

Kategorisiere jede Entität als:
- PERSON (Namen von Personen)
- ORGANIZATION (Firmen, Institutionen, Gruppen)
- CONCEPT (Fachbegriffe, Methoden, Theorien)
- LOCATION (Orte, Länder)
- DATE (Zeitangaben)
- OTHER (Sonstiges)

Antworte NUR im JSON-Format:
{"entities": [{"name": "...", "type": "...", "context": "kurzer Kontext der Erwähnung"}]}

Text:
{{TEXT}}"""

    # Build prompt from YAML or plain text
    prompt = _build_prompt_from_yaml(prompt_content, text[:PROMPT_TEXT_LIMIT])

    try:
        # perf_counter is monotonic: immune to wall-clock adjustments mid-call.
        start_time = time.perf_counter()
        message = client.messages.create(
            model=ANTHROPIC_MODEL, max_tokens=LLM_MAX_TOKENS, messages=[{"role": "user", "content": prompt}]
        )
        duration_ms = int((time.perf_counter() - start_time) * MS_PER_SECOND)

        response_text = message.content[0].text

        protokoll.log_llm_call(
            request=f"[entity_extraction] {prompt[:500]}...",
            response=response_text[:2000],
            model_name=ANTHROPIC_MODEL,
            tokens_input=message.usage.input_tokens,
            tokens_output=message.usage.output_tokens,
            duration_ms=duration_ms,
            status="completed",
        )
    except Exception as e:
        db.log("ERROR", f"Anthropic entity extraction failed: {e}")
        protokoll.log_llm_call(
            request=f"[entity_extraction] {prompt[:500]}...",
            model_name=ANTHROPIC_MODEL,
            status="error",
            error_message=str(e),
        )
        return []

    # The call has already been logged as completed. Parse/normalization problems
    # go to the DB log only, so they do not create a contradictory "error" call record.
    try:
        json_match = re.search(r"\{[\s\S]*\}", response_text)
        if not json_match:
            return []
        result = json.loads(json_match.group())
        # Shared normalization: handles both response formats and filters hallucinations.
        return _normalize_entity_response(result, text)
    except json.JSONDecodeError:
        db.log("WARNING", "Failed to parse entity JSON from Anthropic")
        return []
    except Exception as e:
        db.log("ERROR", f"Anthropic entity extraction failed: {e}")
        return []


# Relevance stored for every extracted entity-document link.
_DOCUMENT_ENTITY_RELEVANCE = 0.8


def _fetch_entity_row(query: str, params: tuple) -> dict | None:
    """Run a SELECT and return its first row (or None), always closing the cursor."""
    cursor = db.execute(query, params)
    try:
        return cursor.fetchone()
    finally:
        cursor.close()


def _execute_and_commit(query: str, params: tuple, *, return_lastrowid: bool = False) -> int | None:
    """Run a write statement and commit it, always closing the cursor.

    Returns ``cursor.lastrowid`` when ``return_lastrowid`` is True (for INSERTs),
    otherwise None.
    """
    cursor = db.execute(query, params)
    try:
        db.commit()
        return cursor.lastrowid if return_lastrowid else None
    finally:
        cursor.close()


def _upsert_entity(document_id: int, entity: dict) -> int:
    """Return the id of the entities row matching ``entity``, creating the row if needed.

    Lookup order, both with the same type:
    1. canonical_name
    2. exact name

    An existing row gets its description (plus canonical_name) filled in if it had
    no description; otherwise only a NULL canonical_name is backfilled. A new row is
    inserted with status 'normalized', and a provenance record pointing at
    ``document_id`` is written.

    Raises:
        KeyError if ``entity`` lacks ``name`` or ``type``, plus any DB error.
    """
    name = entity["name"]
    entity_type = entity["type"]
    description = entity.get("description") or entity.get("context") or None
    canonical = normalize_name(name)

    # Deduplicate via canonical_name first, then fall back to the exact name.
    existing = _fetch_entity_row(
        "SELECT id, description FROM entities WHERE canonical_name = %s AND type = %s",
        (canonical, entity_type),
    )
    if not existing:
        existing = _fetch_entity_row(
            "SELECT id, description FROM entities WHERE name = %s AND type = %s",
            (name, entity_type),
        )

    if existing:
        entity_id = existing["id"]
        if description and not existing["description"]:
            # Fill in the missing description and set canonical_name alongside it.
            _execute_and_commit(
                "UPDATE entities SET description = %s, canonical_name = %s WHERE id = %s",
                (description, canonical, entity_id),
            )
        else:
            # Only backfill canonical_name; an existing value is left untouched.
            _execute_and_commit(
                "UPDATE entities SET canonical_name = %s WHERE id = %s AND canonical_name IS NULL",
                (canonical, entity_id),
            )
        return entity_id

    entity_id = _execute_and_commit(
        """INSERT INTO entities (name, type, description, canonical_name, status, created_at)
           VALUES (%s, %s, %s, %s, 'normalized', NOW())""",
        (name, entity_type, description, canonical),
        return_lastrowid=True,
    )
    db.log_provenance(
        artifact_type="entity",
        artifact_id=entity_id,
        source_type="extraction",
        source_id=document_id,
        pipeline_step="entity_extract",
    )
    return entity_id


def _link_entity_to_document(document_id: int, entity_id: int) -> None:
    """Link an entity to a document (idempotent via INSERT IGNORE); failures are only logged."""
    try:
        _execute_and_commit(
            """INSERT IGNORE INTO document_entities (document_id, entity_id, relevance, created_at)
               VALUES (%s, %s, %s, NOW())""",
            (document_id, entity_id, _DOCUMENT_ENTITY_RELEVANCE),
        )
    except Exception as link_err:
        db.log("WARNING", f"Failed to link entity {entity_id} to document {document_id}: {link_err}")


def store_entities(document_id: int, entities: list[dict]) -> int:
    """Store extracted entities and link each one to ``document_id``.

    Input:
    - Each entity dict needs ``name`` and ``type``.
    - ``description`` (or, failing that, ``context``) is optional.

    Storage:
    - Entities are deduplicated against existing rows by canonical_name, then by
      exact name, with the same type.
    - New rows are inserted with status 'normalized' plus a provenance record.

    Failures:
    - Per-entity failures are logged as warnings and skipped.
    - A failed document link is logged, but the entity still counts as stored.

    Returns:
        Number of entities that were found or created.
    """
    stored = 0

    for entity in entities:
        try:
            entity_id = _upsert_entity(document_id, entity)
            _link_entity_to_document(document_id, entity_id)
            stored += 1
        except Exception as e:
            db.log("WARNING", f"Failed to store entity: {e}")

    return stored


def find_entity_by_name(name: str) -> dict | None:
    """Find entity by name or canonical_name."""
    name_lower = name.lower().strip()
    canonical = normalize_name(name)

    cursor = db.execute(
        "SELECT id, name, canonical_name FROM entities WHERE canonical_name = %s LIMIT 1",
        (canonical,),
    )
    result = cursor.fetchone()
    cursor.close()
    if result:
        return result

    cursor = db.execute(
        "SELECT id, name, canonical_name FROM entities WHERE LOWER(name) = %s LIMIT 1",
        (name_lower,),
    )
    result = cursor.fetchone()
    cursor.close()
    if result:
        return result

    cursor = db.execute(
        "SELECT id, name, canonical_name FROM entities WHERE canonical_name LIKE %s LIMIT 1",
        (f"%{canonical}%",),
    )
    result = cursor.fetchone()
    cursor.close()
    return result
← Übersicht