# rag/qa.py
import json
import os
from typing import List, Dict
import numpy as np
from openai import OpenAI

from rag.config import (
    OPENAI_API_KEY,
    EMBED_MODEL,
    LLM_MODEL,
    INDEX_DIR,
    ARTIFACTS_DIR,
    CATEGORIES,
    TOP_K_PER_CATEGORY,
    MAX_FINDINGS_PER_CATEGORY,
    RUN_ID,
    INPUT_DIR
)
from rag.vectorstore import FaissStore

# One shared OpenAI client, reused by every embedding and chat call in this module.
client = OpenAI(
    api_key=OPENAI_API_KEY,
)


# -------------------------------------------------
# Embedding
# -------------------------------------------------
def embed_query(query: str) -> np.ndarray:
    """
    Embed one query string with ``EMBED_MODEL``.

    Args:
        query: Text to embed.

    Returns:
        A float32 array of shape ``(1, dim)``: a single-row batch, which is
        the form handed to ``FaissStore.search`` elsewhere in this module.
    """
    response = client.embeddings.create(model=EMBED_MODEL, input=[query])
    first_vector = response.data[0].embedding
    return np.array([first_vector], dtype="float32")


# -------------------------------------------------
# Answer question with LLM (for main.py ask command)
# -------------------------------------------------
def answer_question_with_llm(
    question: str,
    search_results: List[Dict],
    *,
    max_chunks: int = 15,
    max_context_chars: int = 12000,
) -> str:
    """
    Generate an LLM answer grounded in retrieved chunks.

    Used by the ``ask`` command in main.py.

    Args:
        question: The user's question, inserted verbatim into the prompt.
        search_results: Search hits; each may carry a ``metadata`` dict with
            ``doc_name``, ``page_number`` and ``text``. Missing or ``None``
            metadata falls back to "Unknown" / "?" / empty text.
        max_chunks: Maximum number of hits (in the given order) to include.
        max_context_chars: Character budget applied to the joined context.

    Returns:
        The stripped model answer; a fixed message when ``search_results`` is
        empty; or a string starting with ``"Error generating answer:"`` when
        the API call fails or the model returns no content.
    """
    if not search_results:
        return "No relevant information found in the documents."

    # Format only the chunks that will actually be sent to the model.
    context_parts = []
    for r in search_results[:max_chunks]:
        # `or {}` also covers an explicit None, which .get(..., {}) would not.
        meta = r.get("metadata") or {}
        doc_name = meta.get("doc_name", "Unknown")
        page = meta.get("page_number", "?")
        text = meta.get("text") or ""
        context_parts.append(f"[{doc_name}, Page {page}]\n{text}")

    # Hard character cap keeps the prompt bounded; it may cut the last chunk.
    context = "\n\n".join(context_parts)[:max_context_chars]

    prompt = f"""You are an AI assistant analyzing annual reports and financial documents.

Question: {question}

Context from documents:
{context}

Instructions:
- Provide a clear, factual answer based ONLY on the context above
- Cite sources using [Document Name, Page X] format after each claim
- If the information is not in the context, say "Information not available in the provided documents"
- Be concise but complete

Answer:"""

    try:
        response = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=600
        )
        content = response.choices[0].message.content
    except Exception as e:
        # CLI boundary: report the failure as text instead of crashing `ask`.
        return f"Error generating answer: {str(e)}"

    # The SDK types message.content as optional; don't call .strip() on None.
    if not content:
        return "Error generating answer: the model returned an empty response"
    return content.strip()


# -------------------------------------------------
# Evidence validation (STRICT)
# -------------------------------------------------
def _normalize_for_match(text: str) -> str:
    """Lower-case *text* and collapse every whitespace run to a single space."""
    return " ".join(text.split()).lower()


def validate_evidence(evidence: Dict, chunk_map: Dict[str, Dict]) -> bool:
    """
    Validate that evidence actually exists in the retrieved chunks.

    Checks (all must pass):
      * ``evidence`` is a dict whose ``chunk_id`` is a key of ``chunk_map``.
      * ``quote`` is a non-blank string contained in the chunk text. The
        comparison ignores case and collapses runs of whitespace, so a quote
        spanning a line break in the extracted text still matches.
      * ``doc_name`` equals the chunk's ``doc_name`` exactly.
      * ``page`` equals the chunk's ``page_number``. ``12`` and ``"12"`` are
        treated as the same page because LLM JSON output is not type-stable.

    Args:
        evidence: One evidence object taken from the model's JSON output.
        chunk_map: Retrieved chunks keyed by ``chunk_id``; each needs
            ``text``, ``doc_name`` and ``page_number`` entries.

    Returns:
        True if the evidence is verified, False otherwise. Malformed evidence
        yields False rather than raising.
    """
    if not isinstance(evidence, dict):
        return False

    try:
        chunk = chunk_map.get(evidence.get("chunk_id"))
    except TypeError:  # unhashable chunk_id, e.g. a list emitted by the model
        return False
    if not chunk:
        return False

    quote = evidence.get("quote")
    if not isinstance(quote, str):
        return False
    needle = _normalize_for_match(quote)
    if not needle:  # empty or whitespace-only quote proves nothing
        return False
    if needle not in _normalize_for_match(chunk["text"] or ""):
        return False

    if evidence.get("doc_name") != chunk["doc_name"]:
        return False

    page = evidence.get("page")
    expected_page = chunk["page_number"]
    if page == expected_page:
        return True
    # Accept 12 vs "12": compare textual forms when both are present.
    return (
        page is not None
        and expected_page is not None
        and str(page).strip() == str(expected_page).strip()
    )


# -------------------------------------------------
# Extract findings per category
# -------------------------------------------------
def _parse_llm_json(raw) -> Dict:
    """Parse a model reply (str or None) into a JSON object, tolerating ``` fences.

    Raises:
        ValueError: if the reply is empty, is not valid JSON (json.JSONDecodeError
            is a ValueError subclass), or decodes to something other than an object.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        # Models sometimes wrap JSON in ```json ... ``` despite instructions:
        # drop the opening fence line and a trailing fence.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    result = json.loads(text)
    if not isinstance(result, dict):
        raise ValueError(f"expected a JSON object, got {type(result).__name__}")
    return result


def _collect_valid_findings(result: Dict, chunk_map: Dict[str, Dict], limit: int) -> List[Dict]:
    """Keep up to *limit* findings that have at least one verified evidence item."""
    raw_findings = result.get("findings")
    if not isinstance(raw_findings, list):
        return []

    findings = []
    for finding in raw_findings:
        # Cap AFTER validation so rejected findings don't consume slots.
        if len(findings) >= limit:
            break
        if not isinstance(finding, dict):
            continue
        evidence = finding.get("evidence", [])
        if not isinstance(evidence, list):
            continue
        valid_evidence = [
            e for e in evidence
            if isinstance(e, dict) and validate_evidence(e, chunk_map)
        ]
        if valid_evidence:
            finding["evidence"] = valid_evidence
            findings.append(finding)
    return findings


def extract_findings_per_category(category: str, query: str) -> List[Dict]:
    """
    Retrieve chunks for one category and have the LLM extract evidenced findings.

    For a given category (e.g. 'Risks', 'Management'), searches the FAISS
    index with ``query``, writes the retrieved chunks to
    ``ARTIFACTS_DIR/retrieval_bundle_<category>.json`` for debugging, asks the
    LLM for findings as JSON, and keeps only findings whose evidence passes
    ``validate_evidence`` against those chunks.

    Args:
        category: Category label, used in the prompt and the bundle filename.
        query: Retrieval query text for this category.

    Returns:
        At most ``MAX_FINDINGS_PER_CATEGORY`` findings, each holding only its
        validated evidence. Returns ``[]`` (after printing a warning) when the
        API call fails or the reply is not a JSON object.
    """
    store = FaissStore.load(INDEX_DIR)
    qvec = embed_query(query)
    hits = store.search(qvec, k=TOP_K_PER_CATEGORY)

    bundle = []
    chunk_map = {}
    for h in hits:
        meta = h["metadata"]
        entry = {
            "chunk_id": meta["chunk_id"],
            "doc_name": meta["doc_name"],
            "page_number": meta["page_number"],
            "source_url": meta["source_url"],
            "text": meta["text"]
        }
        bundle.append(entry)
        chunk_map[meta["chunk_id"]] = entry

    # Save exactly what the model sees, for debugging.
    os.makedirs(ARTIFACTS_DIR, exist_ok=True)
    safe_name = category.lower().replace(' ', '_').replace('/', '_')
    bundle_path = os.path.join(ARTIFACTS_DIR, f"retrieval_bundle_{safe_name}.json")
    with open(bundle_path, "w", encoding="utf-8") as fh:
        json.dump(bundle, fh, indent=2, ensure_ascii=False)

    prompt = f"""
You are a pharma-grade annual report extraction engine.

Category: {category}

RULES:
- Return VALID JSON ONLY
- Use ONLY the provided chunks
- If unsupported, OMIT the finding
- Each finding must include evidence with doc_name, page, chunk_id, quote (≤25 words)

CHUNKS:
{json.dumps(bundle, indent=2)}

OUTPUT SCHEMA:
{{
  "category": "{category}",
  "findings": [
    {{
      "finding": "Short description",
      "severity": "High|Medium|Low",
      "confidence": "High|Medium|Low",
      "evidence": [
        {{
          "doc_name": "...",
          "page": 123,
          "chunk_id": "...",
          "quote": "exact quote from text"
        }}
      ]
    }}
  ]
}}
"""

    try:
        resp = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        result = _parse_llm_json(resp.choices[0].message.content)
    except Exception as e:
        # Best-effort per category: one bad reply must not abort the whole run.
        print(f"⚠️ Error extracting findings for {category}: {e}")
        return []

    return _collect_valid_findings(result, chunk_map, MAX_FINDINGS_PER_CATEGORY)


# -------------------------------------------------
# Generate findings.json (Stage 1 output)
# -------------------------------------------------
# Score weight per severity label, matched case-insensitively; unknown -> 1.
_SEVERITY_WEIGHTS = {"high": 3, "medium": 2, "low": 1}


def _severity_weight(severity) -> int:
    """Return the score weight for an LLM-supplied severity label (default 1)."""
    if not isinstance(severity, str):
        return 1
    return _SEVERITY_WEIGHTS.get(severity.strip().lower(), 1)


def generate_findings_json() -> str:
    """
    Generate the master findings.json file (Stage 1 output).

    Steps:
      1. Extract validated findings for every entry in ``CATEGORIES`` and
         assign sequential ids ``F-001``, ``F-002``, ...
      2. Score each finding as ``severity_weight * (1 + recurrence_count)``.
         ``recurrence_count`` is the number of distinct documents cited in
         its evidence. The weight is High=3, Medium=2, and Low or anything
         else is 1, with labels matched case-insensitively.
      3. Sort by score (descending, stable) and record the top 10 ids.

    Returns:
        Path of the written ``findings.json`` inside ``ARTIFACTS_DIR``.
    """
    all_findings = []
    next_id = 1

    print("\n🔍 Extracting findings per category...\n")

    for category, query in CATEGORIES.items():
        print(f"  Processing: {category}...")
        findings = extract_findings_per_category(category, query)
        for finding in findings:
            finding["id"] = f"F-{next_id:03d}"
            finding["category"] = category
            next_id += 1
            all_findings.append(finding)
        print(f"    → Found {len(findings)} findings")

    # Calculate recurrence and score
    for finding in all_findings:
        docs_cited = {e.get("doc_name") for e in finding.get("evidence", [])}
        finding["recurrence_count"] = len(docs_cited)
        finding["score"] = (
            _severity_weight(finding.get("severity")) * (1 + finding["recurrence_count"])
        )

    # Sort by score and pick top 10 (stable sort keeps extraction order on ties)
    all_findings.sort(key=lambda x: x["score"], reverse=True)
    top_10_ids = [finding["id"] for finding in all_findings[:10]]

    findings_data = {
        "run_id": RUN_ID,
        "input_folder": INPUT_DIR,
        "total_findings": len(all_findings),
        "findings": all_findings,
        "top_10_ids": top_10_ids
    }

    os.makedirs(ARTIFACTS_DIR, exist_ok=True)
    out_path = os.path.join(ARTIFACTS_DIR, "findings.json")
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(findings_data, fh, indent=2, ensure_ascii=False)

    print(f"\n✅ Total findings: {len(all_findings)}")
    print(f"✅ Top 10 critical findings: {', '.join(top_10_ids)}")
    print(f"✅ findings.json saved at: {out_path}\n")

    return out_path


if __name__ == "__main__":
    # Running the module directly builds findings.json end to end.
    output_path = generate_findings_json()
    print(f"✅ findings.json generated at {output_path}")
