relation_extractor.py

Code Hygiene Score: 80

Keine Issues gefunden.

Dependencies 13

Funktionen 1

Code

"""
Relation Extraction - Extract relations between entities.
"""

import json
import re
import sys
import time

import requests

sys.path.insert(0, "/var/www/scripts/pipeline")

from config import ANTHROPIC_MODEL, OLLAMA_CHAT_MODEL, OLLAMA_HOST
from constants import BATCH_LIMIT, LLM_TIMEOUT, MS_PER_SECOND
from db import db
from protokoll import protokoll


# Fallback prompt used when the "relation_extraction" prompt is missing from the DB.
_FALLBACK_RELATION_PROMPT = """Identifiziere Beziehungen zwischen Entitäten.
Entitäten: {{ENTITIES}}
Beziehungstypen: DEVELOPED_BY, RELATED_TO, PART_OF, USED_IN, BASED_ON
Antworte NUR im JSON-Format:
{"relations": [{"source": "...", "relation": "...", "target": "..."}]}

Text:
{{TEXT}}"""

# Only this many entities / text characters are put into the prompt, to keep it bounded.
_MAX_PROMPT_ENTITIES = 20
_MAX_PROMPT_TEXT_CHARS = 3000

# Greedy match from the first "{" to the last "}", so a JSON object wrapped in
# extra model chatter is still found and nested objects stay intact.
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}")


def _build_relation_prompt(text: str, entity_names: list[str]) -> str:
    """Fill the DB prompt template (or the built-in fallback) with entity names and text."""
    template = db.get_prompt("relation_extraction")
    if not template:
        db.log("WARNING", "relation_extraction prompt not found in DB, using fallback")
        template = _FALLBACK_RELATION_PROMPT
    prompt = template.replace("{{ENTITIES}}", ", ".join(entity_names))
    return prompt.replace("{{TEXT}}", text[:_MAX_PROMPT_TEXT_CHARS])


def _call_relation_llm(prompt: str, client) -> tuple[str, int, int]:
    """Send *prompt* to the LLM and return ``(response_text, tokens_in, tokens_out)``.

    Uses ``client.messages.create`` with ANTHROPIC_MODEL when *client* is truthy,
    otherwise POSTs to the Ollama ``/api/generate`` endpoint at OLLAMA_HOST.
    Any transport/HTTP error propagates to the caller.
    """
    if client:
        # NOTE(review): BATCH_LIMIT is used as the max_tokens cap here — confirm
        # this is the intended constant.
        message = client.messages.create(
            model=ANTHROPIC_MODEL, max_tokens=BATCH_LIMIT, messages=[{"role": "user", "content": prompt}]
        )
        return message.content[0].text, message.usage.input_tokens, message.usage.output_tokens

    response = requests.post(
        f"{OLLAMA_HOST}/api/generate",
        json={"model": OLLAMA_CHAT_MODEL, "prompt": prompt, "stream": False, "format": "json"},
        timeout=LLM_TIMEOUT,
    )
    response.raise_for_status()
    body = response.json()
    return body.get("response", "{}"), body.get("prompt_eval_count", 0), body.get("eval_count", 0)


def _parse_relations(response_text: str) -> list[dict]:
    """Return the ``"relations"`` list from the model's JSON reply, or ``[]`` if absent/malformed."""
    match = _JSON_OBJECT_RE.search(response_text)
    if not match:
        return []
    try:
        payload = json.loads(match.group())
    except json.JSONDecodeError as e:
        # Logged to the app log only: the LLM call itself succeeded and has
        # already been recorded as "completed", so no second LLM-call record.
        db.log("ERROR", f"Relation extraction failed: {e}")
        return []
    if not isinstance(payload, dict):
        return []
    relations = payload.get("relations")
    if not isinstance(relations, list):
        return []
    # Drop non-object entries so callers always receive list[dict].
    return [rel for rel in relations if isinstance(rel, dict)]


def extract_relations(text: str, entities: list[dict], client=None) -> list[dict]:
    """Extract relations between entities.

    Builds a prompt from the "relation_extraction" DB template (or a built-in
    fallback) with the first entity names and the start of *text*, sends it to
    the LLM and parses the JSON reply.

    Args:
        text: Source text; only the first 3000 characters are sent.
        entities: Entity dicts; each must have a ``"name"`` key. Only the
            first 20 are included in the prompt.
        client: Optional Anthropic-style client (``client.messages.create``).
            When falsy, the local Ollama server is used instead.

    Returns:
        The relation dicts from the reply's ``"relations"`` list. Returns ``[]``
        when fewer than two entities are given, the LLM call fails, or the
        reply contains no usable JSON.

    Raises:
        KeyError: If one of the first 20 entities has no ``"name"`` key.
    """
    if not entities or len(entities) < 2:
        return []

    entity_names = [e["name"] for e in entities[:_MAX_PROMPT_ENTITIES]]
    prompt = _build_relation_prompt(text, entity_names)
    model_name = ANTHROPIC_MODEL if client else f"ollama:{OLLAMA_CHAT_MODEL}"
    request_log = f"[relation_extraction] {prompt[:500]}..."

    # Only the LLM call (and its success log) is guarded; parsing happens
    # afterwards so a malformed reply is not logged as a second, failed call.
    try:
        start = time.monotonic()  # monotonic: immune to wall-clock adjustments
        response_text, tokens_in, tokens_out = _call_relation_llm(prompt, client)
        duration_ms = int((time.monotonic() - start) * MS_PER_SECOND)

        protokoll.log_llm_call(
            request=request_log,
            response=response_text[:2000],
            model_name=model_name,
            tokens_input=tokens_in,
            tokens_output=tokens_out,
            duration_ms=duration_ms,
            status="completed",
        )
    except Exception as e:
        db.log("ERROR", f"Relation extraction failed: {e}")
        protokoll.log_llm_call(
            request=request_log,
            model_name=model_name,
            status="error",
            error_message=str(e),
        )
        return []

    return _parse_relations(response_text)
← Übersicht