# rag/vectorstore.py
import os
import json
import faiss
import numpy as np
from typing import List, Dict


REQUIRED_META_FIELDS = {
    "tenant_id",
    "doc_id",
    "doc_name",
    "source_url",
    "page_number",
    "section_heading",
    "chunk_id",
    "chunk_order",
    "text"
}


class FaissStore:
    """
    Enterprise-grade FAISS vector store with
    strict metadata validation (POC compliant).
    """

    def __init__(self, dim: int):
        self.dim = dim
        self.index = faiss.IndexFlatIP(dim)  # inner product == cosine on L2-normalized vectors
        self.meta: List[Dict] = []

    # -------------------------------
    # Vector normalization
    # -------------------------------
    @staticmethod
    def _normalize(vectors: np.ndarray) -> np.ndarray:
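        """L2-normalize row vectors; the epsilon guards against division by zero."""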
        norm = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-12
        return vectors / norm

    # -------------------------------
    # Metadata validation (CRITICAL)
    # -------------------------------
    @staticmethod
    def _validate_metadata(meta: Dict):
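        """Raise ValueError if any required metadata field is absent."""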
        missing = REQUIRED_META_FIELDS - set(meta.keys())
        if missing:
            raise ValueError(
                f"Missing required metadata fields: {missing}"
            )

    # -------------------------------
    # Add vectors + metadata
    # -------------------------------
    def add(self, vectors: np.ndarray, metadatas: List[Dict]):
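        """Validate metadata, L2-normalize the vectors, and append both to the store."""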
        if len(vectors) != len(metadatas):
            raise ValueError("Vectors and metadata length mismatch")

        vectors = self._normalize(vectors.astype("float32"))

        for m in metadatas:
            self._validate_metadata(m)

        self.index.add(vectors)
        self.meta.extend(metadatas)

    # -------------------------------
    # Search
    # -------------------------------
    def search(self, query_vec: np.ndarray, k: int = 8):
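        """Return the top-k chunks (score + metadata) for a single query vector."""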
        query_vec = np.atleast_2d(query_vec).astype("float32")  # FAISS expects a (n, dim) matrix
        scores, indices = self.index.search(self._normalize(query_vec), k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:
                continue
            results.append({
                "score": float(score),
                "metadata": self.meta[idx]
            })
        return results

    # -------------------------------
    # Persist index + metadata
    # -------------------------------
    def save(self, path: str):
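        """Write the FAISS index and a JSON metadata sidecar under `path`."""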
        os.makedirs(path, exist_ok=True)

        faiss.write_index(
            self.index,
            os.path.join(path, "index.faiss")
        )

        with open(os.path.join(path, "meta.json"), "w", encoding="utf-8") as f:
            json.dump(
                {
                    "dim": self.dim,
                    "count": len(self.meta),
                    "meta": self.meta
                },
                f,
                indent=2,
                ensure_ascii=False
            )

    # -------------------------------
    # Load index
    # -------------------------------
    @classmethod
    def load(cls, path: str):
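        """Load an index and its metadata previously written by `save`."""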
        index_path = os.path.join(path, "index.faiss")
        meta_path = os.path.join(path, "meta.json")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError("FAISS index or metadata not found")

        index = faiss.read_index(index_path)

        with open(meta_path, "r", encoding="utf-8") as f:
            obj = json.load(f)

        store = cls(dim=obj["dim"])
        store.index = index
        store.meta = obj["meta"]

        return store
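

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the POC pipeline: it assumes a
    # 4-dimensional embedding and the metadata values below are invented purely
    # for illustration. Real vectors would come from the project's embedder.
    rng = np.random.default_rng(0)
    store = FaissStore(dim=4)

    vectors = rng.random((2, 4), dtype=np.float32)
    metadatas = [
        {
            "tenant_id": "tenant-demo",
            "doc_id": f"doc-{i}",
            "doc_name": "example.pdf",
            "source_url": "https://example.com/example.pdf",
            "page_number": 1,
            "section_heading": "Introduction",
            "chunk_id": f"chunk-{i}",
            "chunk_order": i,
            "text": f"Example chunk {i}",
        }
        for i in range(2)
    ]
    store.add(vectors, metadatas)

    # Round-trip through disk, then query with the first vector.
    store.save("faiss_demo_index")
    restored = FaissStore.load("faiss_demo_index")
    for hit in restored.search(vectors[0], k=2):
        print(hit["score"], hit["metadata"]["chunk_id"])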
