embed.py

Code Hygiene Score: 77

Keine Issues gefunden.

Dependencies 14

Funktionen 5

Code

"""
Embedding generation for KI-System Pipeline
Uses Ollama (mxbai-embed-large) for vector embeddings.
"""

import json
import re
import uuid

import requests

from config import EMBEDDING_DIMENSION, EMBEDDING_MODEL, OLLAMA_HOST, QDRANT_HOST, QDRANT_PORT
from constants import BATCH_LIMIT, DEFAULT_LIMIT, OLLAMA_TIMEOUT
from db import db

# Maximum number of characters of text sent to the embedding model per request.
# The mxbai-embed model has a 512-token context and the chars-per-token ratio
# varies with content, so this is a deliberately conservative limit that also
# leaves headroom for German compound words and special characters.
# get_embedding() truncates any longer input to this many characters.
MAX_EMBED_CHARS: int = 800


def get_embedding(text):
    """Request an embedding vector for ``text`` from the Ollama server.

    Before sending, runs of three or more periods (e.g. table-of-contents
    leaders) are collapsed to ``"..."`` and the text is cut to at most
    ``MAX_EMBED_CHARS`` characters so it fits the model's context window.

    Args:
        text: The string to embed. ``None``, empty or whitespace-only input
            is rejected without contacting the server.

    Returns:
        The ``"embedding"`` field of Ollama's JSON reply (``None`` if that
        field is absent), or ``None`` when the input is blank or the request
        fails. Failures are logged via ``db.log`` instead of being raised.
    """
    # Blank input: nothing meaningful to embed.
    if not text or not text.strip():
        return None

    # Squash dot leaders, then clip to the model's character budget.
    # Slicing a shorter string is a no-op, so no separate length test is needed.
    prompt = re.sub(r"\.{3,}", "...", text)[:MAX_EMBED_CHARS]

    try:
        reply = requests.post(
            f"{OLLAMA_HOST}/api/embeddings",
            json={"model": EMBEDDING_MODEL, "prompt": prompt},
            timeout=OLLAMA_TIMEOUT,
        )
        reply.raise_for_status()
        return reply.json().get("embedding")
    except Exception as exc:
        # Best-effort: callers treat None as "could not embed".
        db.log("ERROR", f"Embedding generation failed: {exc}")
        return None


def store_in_qdrant(collection, point_id, vector, payload):
    """Write a single point (vector + payload) to a Qdrant collection via REST.

    Sends one HTTP PUT to ``/collections/<collection>/points`` on
    ``QDRANT_HOST:QDRANT_PORT`` with a 30-second timeout.

    Args:
        collection: Name of the target Qdrant collection.
        point_id: Identifier for the point (callers in this module pass a
            UUID string).
        vector: The embedding vector to store.
        payload: JSON-serialisable metadata dict attached to the point.

    Returns:
        True if Qdrant answered with a success status, False otherwise.
        Errors are logged via ``db.log`` rather than raised.
    """
    point = {"id": point_id, "vector": vector, "payload": payload}
    try:
        resp = requests.put(
            f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{collection}/points",
            json={"points": [point]},
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        resp.raise_for_status()
    except Exception as err:
        db.log("ERROR", f"Qdrant storage failed: {err}")
        return False
    return True


def embed_chunks(chunks, document_id, document_title, source_path, progress=None):
    """
    Generate embeddings for chunks and store them in Qdrant.

    Each chunk's ``"content"`` is embedded with ``get_embedding``. A chunk is
    skipped (and logged) if embedding fails or the vector length differs from
    ``EMBEDDING_DIMENSION``. Each successful vector is stored in the
    ``"documents"`` collection under a fresh UUID. That UUID is then written
    back to the chunk's DB row via ``db.update_chunk_qdrant_id``.

    Args:
        chunks: Sequence of chunk dicts. Each must contain ``"content"``.
            The optional keys ``"heading_path"``, ``"metadata"`` and
            ``"db_id"`` are used when present.
        document_id: ID of the parent document, stored in every payload.
        document_title: Document title, stored in every payload.
        source_path: Source path of the document, stored in every payload.
        progress: Optional object with an ``add_log(str)`` method that
            receives per-chunk progress messages.

    Returns:
        Number of chunks that were embedded and stored successfully.
    """
    embedded_count = 0
    total_chunks = len(chunks)

    for i, chunk in enumerate(chunks):
        # Log every chunk for full visibility (1-based position for humans).
        if progress:
            progress.add_log(f"Embed: Chunk {i + 1}/{total_chunks}...")

        embedding = get_embedding(chunk["content"])
        if not embedding:
            db.log("WARNING", f"Failed to embed chunk {i} of document {document_id}")
            continue

        # Skip vectors whose size does not match the configured dimension
        # (e.g. a model/config mismatch) instead of writing them to Qdrant.
        if len(embedding) != EMBEDDING_DIMENSION:
            db.log("ERROR", f"Wrong embedding dimension: {len(embedding)} vs {EMBEDDING_DIMENSION}")
            continue

        point_id = str(uuid.uuid4())

        payload = {
            "document_id": document_id,
            "document_title": document_title,
            "chunk_index": i,
            "content": chunk["content"][:BATCH_LIMIT],  # Truncate for payload
            "heading_path": json.dumps(chunk.get("heading_path", [])),
            "source_path": source_path,
        }

        # Copy scalar chunk metadata into the payload. Keys that collide with
        # the core fields above are ignored, so metadata can never overwrite
        # document_id, content, chunk_index, etc.
        for key, value in (chunk.get("metadata") or {}).items():
            if key not in payload and isinstance(value, (str, int, float, bool)):
                payload[key] = value

        if store_in_qdrant("documents", point_id, embedding, payload):
            # Record the Qdrant point ID on the chunk's DB row.
            db.update_chunk_qdrant_id(chunk.get("db_id"), point_id)
            embedded_count += 1
            db.log("INFO", f"Embedded chunk {i + 1}/{total_chunks}", f"doc={document_id}")
        else:
            db.log("ERROR", f"Failed to store chunk {i} in Qdrant")

    return embedded_count
    return embedded_count


def _fetch_document_info(document_id):
    """Return the ``filename``/``source_path`` row for ``document_id`` (None if no row)."""
    cursor = db.execute("SELECT filename, source_path FROM documents WHERE id = %s", (document_id,))
    try:
        return cursor.fetchone()
    finally:
        # Close the cursor even if fetchone() raises.
        cursor.close()


def embed_pending_chunks(limit=DEFAULT_LIMIT):
    """Embed chunks that have not been embedded yet and store them in Qdrant.

    Opens the DB connection and fetches up to ``limit`` chunks via
    ``db.get_chunks_for_embedding``. Each chunk is embedded, stored in the
    ``"documents"`` collection, and its new point ID is recorded on the chunk
    row. The DB connection is always closed on exit.

    Args:
        limit: Maximum number of chunks to process in this run.

    Returns:
        Number of chunks successfully embedded and stored.

    Raises:
        Re-raises any unexpected exception after logging it.
    """
    db.connect()

    try:
        chunks = db.get_chunks_for_embedding(limit)
        db.log("INFO", f"Found {len(chunks)} chunks to embed")

        if not chunks:
            return 0

        # Document metadata is the same for every chunk of a document, so it
        # is looked up once per document instead of once per chunk.
        doc_cache = {}
        embedded = 0
        for chunk in chunks:
            embedding = get_embedding(chunk["content"])
            if not embedding:
                db.log("WARNING", f"Failed to embed chunk {chunk['id']}")
                continue

            # Same guard as embed_chunks: never write a wrongly-sized vector.
            if len(embedding) != EMBEDDING_DIMENSION:
                db.log("ERROR", f"Wrong embedding dimension: {len(embedding)} vs {EMBEDDING_DIMENSION}")
                continue

            point_id = str(uuid.uuid4())

            document_id = chunk["document_id"]
            if document_id not in doc_cache:
                doc_cache[document_id] = _fetch_document_info(document_id)
            doc = doc_cache[document_id]

            payload = {
                "document_id": document_id,
                "document_title": doc["filename"] if doc else "",
                "chunk_id": chunk["id"],
                "content": chunk["content"][:BATCH_LIMIT],
                "source_path": doc["source_path"] if doc else "",
            }

            if store_in_qdrant("documents", point_id, embedding, payload):
                db.update_chunk_qdrant_id(chunk["id"], point_id)
                embedded += 1

        db.log("INFO", f"Embedded {embedded}/{len(chunks)} chunks")
        return embedded

    except Exception as e:
        db.log("ERROR", f"Embedding error: {e}")
        raise
    finally:
        db.disconnect()


def search_similar(query, collection="documents", limit=5):
    """Find the points in a Qdrant collection most similar to ``query``.

    The query text is embedded with ``get_embedding``. The vector is then sent
    to the collection's ``/points/search`` endpoint with payloads included.

    Args:
        query: Free-text search query.
        collection: Qdrant collection to search.
        limit: Maximum number of hits requested from Qdrant.

    Returns:
        The ``"result"`` list from Qdrant's JSON response (``[]`` if that key
        is missing), or ``[]`` when the query cannot be embedded or the search
        request fails. Failures are logged via ``db.log``.
    """
    query_vector = get_embedding(query)
    if not query_vector:
        return []

    search_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{collection}/points/search"
    try:
        resp = requests.post(
            search_url,
            json={"vector": query_vector, "limit": limit, "with_payload": True},
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json().get("result", [])
    except Exception as err:
        db.log("ERROR", f"Qdrant search failed: {err}")
        return []


if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        # No query given: process chunks that are still waiting for an embedding.
        count = embed_pending_chunks()
        print(f"Embedded {count} chunks")
    else:
        # All command-line arguments together form a single search query.
        query = " ".join(cli_args)
        print(f"Searching for: {query}")
        print("-" * 50)

        for rank, hit in enumerate(search_similar(query), start=1):
            print(f"\n[{rank}] Score: {hit['score']:.4f}")
            payload = hit["payload"]
            print(f"    Document: {payload.get('document_title', 'Unknown')}")
            print(f"    Content: {payload.get('content', '')[:200]}...")
← Übersicht