entity_normalizer.py

Code Hygiene Score: 100

Keine Issues gefunden.

Dependencies 3

Klassen 1

Funktionen 3

Code

"""
Entity Type Normalizer
Deterministic rules for entity type assignment.
Reads rules from config/entity_type_rules.yaml.
"""

import logging
import re
from pathlib import Path

import yaml


class EntityNormalizer:
    """Normalizes entity types based on deterministic rules."""

    def __init__(self, rules_path: str | None = None):
        if rules_path is None:
            rules_path = Path(__file__).parent.parent / "config" / "entity_type_rules.yaml"

        self.rules_path = Path(rules_path)
        self.rules = self._load_rules()

        # Build lookup structures
        self._explicit_map: dict[str, str] = {}
        self._pattern_rules: list[tuple[re.Pattern, str]] = []
        self._stopwords: set[str] = set()
        self._default_type = "CONCEPT"

        self._build_lookups()

    def _load_rules(self) -> dict:
        """Load rules from YAML file."""
        if not self.rules_path.exists():
            return {}

        with open(self.rules_path, encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _build_lookups(self) -> None:
        """Build efficient lookup structures from rules."""
        # Explicit mappings (case-insensitive lookup)
        for entity_type, names in self.rules.get("explicit_mappings", {}).items():
            for name in names:
                self._explicit_map[name.lower()] = entity_type

        # Pattern rules (compile regexes)
        for entity_type, patterns in self.rules.get("pattern_rules", {}).items():
            for pattern in patterns:
                try:
                    compiled = re.compile(pattern, re.IGNORECASE)
                    self._pattern_rules.append((compiled, entity_type))
                except re.error:
                    pass

        # Stopwords
        self._stopwords = {w.lower() for w in self.rules.get("stopwords", [])}

        # Default type
        self._default_type = self.rules.get("default_type", "CONCEPT")

    def is_stopword(self, name: str) -> bool:
        """Check if entity name is a stopword."""
        return name.lower() in self._stopwords

    def normalize_type(self, name: str, llm_type: str | None = None) -> str:
        """
        Determine the correct type for an entity.

        Priority:
        1. Explicit mapping (highest)
        2. Pattern rules
        3. LLM suggestion (if valid)
        4. Default type
        """
        name_lower = name.lower()

        # 1. Check explicit mapping
        if name_lower in self._explicit_map:
            return self._explicit_map[name_lower]

        # 2. Check pattern rules
        for pattern, entity_type in self._pattern_rules:
            if pattern.search(name):
                return entity_type

        # 3. Use LLM type if valid
        valid_types = {
            "PERSON",
            "ROLE",
            "ORGANIZATION",
            "LOCATION",
            "THEORY",
            "METHOD",
            "MODEL",
            "CONCEPT",
            "ARTIFACT",
            "METAPHOR",
            "PRINCIPLE",
            "TOOL",
            "EVENT",
            "OTHER",
        }
        if llm_type and llm_type.upper() in valid_types:
            return llm_type.upper()

        # 4. Default
        return self._default_type

    def normalize_entity(self, entity: dict) -> dict | None:
        """
        Normalize a single entity.
        Returns None if entity should be filtered (stopword).
        """
        name = entity.get("name", "")

        if not name or len(name) < 3:
            return None

        if self.is_stopword(name):
            return None

        llm_type = entity.get("type")
        normalized_type = self.normalize_type(name, llm_type)

        return {
            "name": name,
            "type": normalized_type,
            "description": entity.get("description"),
        }

    def normalize_entities(self, entities: list[dict]) -> list[dict]:
        """Normalize a list of entities, filtering stopwords."""
        result = []
        for entity in entities:
            normalized = self.normalize_entity(entity)
            if normalized:
                result.append(normalized)
        return result


# Process-wide normalizer shared by the module-level convenience functions;
# created lazily on first use.
_normalizer: EntityNormalizer | None = None


def get_normalizer() -> EntityNormalizer:
    """Return the shared EntityNormalizer, constructing it on first call.

    The instance uses the default rules path and is reused for every
    subsequent call.
    """
    global _normalizer
    if _normalizer is not None:
        return _normalizer
    _normalizer = EntityNormalizer()
    return _normalizer


def normalize_entity_type(name: str, llm_type: str | None = None) -> str:
    """Resolve the entity type for *name* using the shared normalizer.

    Thin wrapper around ``normalize_type`` on the instance returned by
    ``get_normalizer()``.
    """
    shared = get_normalizer()
    return shared.normalize_type(name, llm_type)


def normalize_entities(entities: list[dict]) -> list[dict]:
    """Normalize a list of raw entities using the shared normalizer.

    Thin wrapper around ``normalize_entities`` on the instance returned by
    ``get_normalizer()``; filtered entities are omitted from the result.
    """
    shared = get_normalizer()
    return shared.normalize_entities(entities)


if __name__ == "__main__":
    # Manual smoke test: run this module directly to print how the rules
    # loaded from the default config path classify a few sample entities.
    normalizer = EntityNormalizer()

    test_cases = [
        {"name": "Coach", "type": "PERSON"},
        {"name": "Klient", "type": "PERSON"},
        {"name": "Steve de Shazer", "type": "PERSON"},
        {"name": "Wunderfrage", "type": "CONCEPT"},
        {"name": "Systemische Therapie", "type": "CONCEPT"},
        {"name": "GROW-Modell", "type": "CONCEPT"},
        {"name": "Reframing", "type": "CONCEPT"},
        # Expected to be filtered, provided it is listed under `stopwords`
        # in the rules file.
        {"name": "Aspekte", "type": "CONCEPT"},
    ]

    print("Entity Normalizer Test:")
    print("-" * 50)
    for entity in test_cases:
        result = normalizer.normalize_entity(entity)
        if result:
            print(f"{entity['name']:30} {entity['type']:10} -> {result['type']}")
        elif normalizer.is_stopword(entity["name"]):
            print(f"{entity['name']:30} FILTERED (stopword)")
        else:
            # normalize_entity also drops empty / very short names; do not
            # mislabel those as stopwords.
            print(f"{entity['name']:30} FILTERED (empty or too short)")
← Übersicht Graph