# rag/ingest.py - Complete Incremental Ingestion System
"""
Smart PDF ingestion system with incremental processing.
Only processes new or modified files, preserving existing index.
"""

import os
import glob
import hashlib
import json
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from pathlib import Path
from datetime import datetime

from rag.config import OPENAI_API_KEY, EMBED_MODEL, CHUNK_SIZE, CHUNK_OVERLAP, INDEX_DIR, TENANT_ID
from rag.pdf_utils import load_pdf_pages, extract_section_heading
from rag.vectorstore import FaissStore


# Single OpenAI client shared by embed_texts(); constructed once, at import time.
client = OpenAI(api_key=OPENAI_API_KEY)


# JSON ledger of ingested PDFs, placed in the head component returned by
# os.path.split(INDEX_DIR). NOTE: with a trailing slash ("index/") that head is
# INDEX_DIR itself; with no parent ("index") it is "" (current working dir).
TRACKING_FILE = os.path.join(os.path.split(INDEX_DIR)[0], "processed_files.json")


# ============================================================================
# FILE TRACKING & HASHING
# ============================================================================

def get_file_hash(filepath: str) -> str:
    """
    Compute the MD5 digest of a file's contents for change detection.

    The file is streamed in 8 KiB blocks so large PDFs are never loaded whole.

    Args:
        filepath: Path of the file to fingerprint

    Returns:
        Hex-encoded MD5 digest, or "" if the file could not be read
        (a warning is printed in that case).
    """
    digest = hashlib.md5()
    try:
        with open(filepath, "rb") as stream:
            while block := stream.read(8192):
                digest.update(block)
    except Exception as err:
        print(f"⚠️  Warning: Could not hash {filepath}: {err}")
        return ""
    return digest.hexdigest()


def load_processed_files() -> dict:
    """
    Read the processed-files ledger stored at TRACKING_FILE.

    A missing or unreadable ledger is treated as empty, which makes every PDF
    look new to the caller (a warning is printed when reading fails).

    Returns:
        Dict keyed by file name; values are the records written by the
        ingestion run (hash, processed_date, chunks, pages).
    """
    if not os.path.exists(TRACKING_FILE):
        return {}

    try:
        with open(TRACKING_FILE, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except Exception as err:
        print(f"⚠️  Warning: Could not load tracking file: {err}")
        return {}


def save_processed_files(processed: dict):
    """
    Persist the processed-files ledger to TRACKING_FILE.

    The JSON is written to a sibling temporary file and then moved over the
    ledger with os.replace, so an interrupted write cannot leave a truncated
    file behind. That matters because load_processed_files() treats an
    unparseable ledger as empty, which would force every PDF to be re-ingested.

    Failures are reported as a warning rather than raised (best-effort).

    Args:
        processed: Dict mapping filename -> metadata (must be JSON-serializable)
    """
    tracking_dir = os.path.dirname(TRACKING_FILE)
    tmp_path = f"{TRACKING_FILE}.tmp"
    try:
        # dirname is "" when the ledger lives in the CWD; os.makedirs("") raises,
        # which previously prevented the ledger from ever being saved.
        if tracking_dir:
            os.makedirs(tracking_dir, exist_ok=True)
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(processed, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, TRACKING_FILE)
    except Exception as e:
        print(f"⚠️  Warning: Could not save tracking file: {e}")
        # Don't leave a half-written temp file lying around.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


# ============================================================================
# CHUNKING & EMBEDDING
# ============================================================================

def stable_chunk_id(doc_id: str, page: int, chunk_order: int, text: str) -> str:
    """
    Derive a deterministic identifier for a chunk.

    The ID is the SHA1 of "doc_id|page|chunk_order|<first 50 chars of text>",
    so re-ingesting identical content yields identical IDs.

    Args:
        doc_id: Document identifier
        page: Page number the chunk came from
        chunk_order: Position of the chunk within the document
        text: Chunk text; only its first 50 characters contribute

    Returns:
        40-character SHA1 hex digest
    """
    parts = (doc_id, page, chunk_order, text[:50])
    key = "|".join(str(part) for part in parts)
    return hashlib.sha1(key.encode("utf-8")).hexdigest()


def chunk_text_generator(text: str, chunk_size: int, overlap: int):
    """
    Generate overlapping character windows over whitespace-normalized text.

    Runs of whitespace (including newlines/tabs) are collapsed to single spaces
    first. Each window is `chunk_size` characters; consecutive windows share
    `overlap` characters. The final window may be shorter.

    Args:
        text: Input text to chunk
        chunk_size: Size of each chunk in characters (must be > 0)
        overlap: Characters shared between consecutive chunks. Negative values
            are treated as 0; values >= chunk_size cannot advance the window and
            fall back to contiguous (non-overlapping) chunks.

    Yields:
        Text chunks as strings (nothing for empty/whitespace-only text)

    Raises:
        ValueError: if chunk_size <= 0 (raised on first iteration, as this is
            a generator). Previously this looped forever.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # A negative overlap would jump past characters and silently drop text.
    overlap = max(0, overlap)

    # Normalize whitespace
    text = " ".join(text.split())
    if not text:
        return

    text_len = len(text)
    start = 0
    while start < text_len:
        end = min(text_len, start + chunk_size)
        yield text[start:end]

        if end == text_len:
            break

        next_start = end - overlap
        # With overlap >= chunk_size the window would not move forward and the
        # loop would never terminate; continue contiguously instead.
        start = next_start if next_start > start else end


def embed_texts(texts: list[str], batch_size: int = 2048) -> np.ndarray:
    """
    Generate embeddings for text chunks using the OpenAI embeddings API.

    Inputs are sent in batches of at most `batch_size` strings, because the
    API caps the number of inputs per request (2048). Within each response
    the vectors are ordered by their reported `index`, so row i of the result
    always corresponds to texts[i].

    Args:
        texts: Sequence of strings to embed (must support len() and slicing)
        batch_size: Maximum number of inputs per API request (must be > 0)

    Returns:
        float32 array of shape (len(texts), embedding_dim). For empty input the
        shape is (0, 1536) regardless of EMBED_MODEL.
        NOTE(review): that fixed width may not match models with other
        dimensions; confirm callers never combine it with such a store.

    Raises:
        ValueError: if batch_size <= 0 and texts is non-empty.
        Exception: any error from the API call is printed and re-raised.
    """
    if not texts:
        return np.zeros((0, 1536), dtype="float32")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    embeddings = []
    try:
        for offset in range(0, len(texts), batch_size):
            batch = texts[offset:offset + batch_size]
            response = client.embeddings.create(
                model=EMBED_MODEL,
                input=batch
            )
            ordered = sorted(response.data, key=lambda d: d.index)
            embeddings.extend(d.embedding for d in ordered)
    except Exception as e:
        print(f"❌ Error generating embeddings: {e}")
        raise
    return np.array(embeddings, dtype="float32")


# ============================================================================
# FILE ANALYSIS
# ============================================================================

def analyze_files_to_process(folder: str) -> tuple[list[str], list[str], dict]:
    """
    Decide which PDFs in `folder` need (re)ingestion.

    Each `*.pdf` directly inside the folder is hashed and compared with the
    ledger from load_processed_files():
      - unknown name            -> new, queued for processing
      - known name, same hash   -> unchanged, skipped
      - known name, other hash  -> modified, queued for processing
      - hashing failed ("")     -> queued for processing, to be safe

    Args:
        folder: Directory containing the PDFs

    Returns:
        (paths to process, names of skipped files, ledger dict as loaded)

    Raises:
        ValueError: if the folder contains no PDF files
    """
    pdf_paths = sorted(glob.glob(os.path.join(folder, "*.pdf")))
    if not pdf_paths:
        raise ValueError(f"No PDF files found in {folder}")

    processed_files = load_processed_files()

    to_process: list[str] = []
    unchanged: list[str] = []

    for path in pdf_paths:
        name = os.path.basename(path)
        digest = get_file_hash(path)

        if not digest:
            to_process.append(path)
            continue

        if name not in processed_files:
            print(f"   ✨ New: {name}")
            to_process.append(path)
            continue

        if processed_files[name].get("hash", "") == digest:
            unchanged.append(name)
        else:
            print(f"   🔄 Modified: {name}")
            to_process.append(path)

    return to_process, unchanged, processed_files


# ============================================================================
# INCREMENTAL INGESTION
# ============================================================================

# Chunks per embeddings request during ingestion: avoids one network round-trip
# per chunk. Presumably well under the API's per-request token cap for typical
# CHUNK_SIZE values — revisit if CHUNK_SIZE is raised substantially.
_EMBED_BATCH_SIZE = 100


def _load_existing_store():
    """Load the index at INDEX_DIR; return None if it is missing or cannot be loaded."""
    if not os.path.exists(INDEX_DIR):
        return None
    print("📂 Loading existing index...")
    try:
        store = FaissStore.load(INDEX_DIR)
    except Exception as e:
        print(f"   ⚠️  Could not load existing index: {e}")
        return None
    print(f"   ✅ Loaded existing index with {len(store.meta)} chunks\n")
    return store


def _create_empty_store():
    """Create an empty FaissStore whose width matches EMBED_MODEL's vectors."""
    # One probe call: the vector width depends on the configured model.
    dim = embed_texts(["dimension probe"]).shape[1]
    return FaissStore(dim=dim)


def _chunk_pdf_pages(pages, pdf_path, doc_name, doc_id):
    """Split extracted pages into per-chunk metadata dicts (text included), in reading order."""
    metas = []
    for page in pages:
        page_num = page["page"]
        text = page["text"]
        if not text.strip():
            continue

        section_heading = extract_section_heading(text)
        for chunk in chunk_text_generator(text, CHUNK_SIZE, CHUNK_OVERLAP):
            if not chunk.strip():
                continue
            chunk_order = len(metas)  # document-wide sequence number
            metas.append({
                "tenant_id": TENANT_ID,
                "doc_id": doc_id,
                "doc_name": doc_name,
                "source_url": pdf_path,
                "page_number": page_num,
                "section_heading": section_heading,
                "chunk_id": stable_chunk_id(doc_id, page_num, chunk_order, chunk),
                "chunk_order": chunk_order,
                "text": chunk
            })
    return metas


def _embed_in_batches(texts, batch_size=_EMBED_BATCH_SIZE):
    """Embed a non-empty list of texts in batches; rows follow the input order."""
    parts = [
        embed_texts(texts[i:i + batch_size])
        for i in range(0, len(texts), batch_size)
    ]
    return np.vstack(parts)


def ingest_folder_incremental(folder: str) -> str:
    """
    Ingest PDFs from folder with incremental processing.

    Only new or modified files are chunked and embedded. Unchanged files are
    skipped and their existing chunks are kept, as long as the existing index
    can be loaded. If the index is missing or unreadable, every PDF is
    re-ingested into a fresh index. Otherwise files recorded as processed
    would be absent from search forever.

    Each document is fully chunked and embedded before it is added to the
    store. A failure part-way through a file therefore never leaves partial
    chunks behind, which would be duplicated on the next run.

    NOTE(review): chunks from the previous version of a *modified* file are
    not removed from a loaded index (no delete operation is used here); a
    warning is printed. Deleting TRACKING_FILE forces a clean rebuild.

    Args:
        folder: Path to folder containing PDFs

    Returns:
        Path to saved index directory
    """
    print(f"\n📂 Scanning folder: {folder}\n")

    files_to_process, skipped_files, processed_files = analyze_files_to_process(folder)
    total_pdfs = len(files_to_process) + len(skipped_files)

    print(f"📊 Ingestion Plan:")
    print(f"   - Total PDFs found: {total_pdfs}")
    print(f"   - Already processed (unchanged): {len(skipped_files)}")
    print(f"   - New or modified: {len(files_to_process)}")

    if skipped_files:
        print(f"\n⏭️  Skipping unchanged files:")
        for name in skipped_files[:5]:  # Show first 5
            print(f"      ✓ {name}")
        if len(skipped_files) > 5:
            print(f"      ... and {len(skipped_files) - 5} more")

    # Nothing changed and the index is still on disk -> nothing to do. If the
    # index has vanished, fall through so the skipped files get re-ingested.
    if not files_to_process and os.path.exists(INDEX_DIR):
        print(f"\n✅ All files already up-to-date! Nothing to do.")
        print(f"\n💡 To reprocess all files, delete: {TRACKING_FILE}")
        return INDEX_DIR

    store = None
    if skipped_files:
        store = _load_existing_store()
        if store is None:
            # Skipped files' chunks only exist in the old index; without it they
            # must be embedded again or they would silently disappear.
            print("   ⚠️  Existing index unavailable; re-ingesting all files into a new index.\n")
            files_to_process = sorted(
                files_to_process + [os.path.join(folder, name) for name in skipped_files]
            )
            skipped_files = []
    loaded_existing = store is not None
    if store is None:
        print("📂 Creating new index...\n")
        store = _create_empty_store()

    print(f"\n🔄 Processing {len(files_to_process)} file(s)...\n")

    ingestion_manifest = []
    total_chunks_added = 0

    for pdf_path in tqdm(files_to_process, desc="Processing PDFs"):
        doc_name = os.path.basename(pdf_path)
        try:
            # Fingerprint the bytes about to be read, not what is on disk afterwards.
            file_hash = get_file_hash(pdf_path)
            doc_id = hashlib.md5(doc_name.encode()).hexdigest()

            pages = load_pdf_pages(pdf_path)
            if not pages:
                print(f"\n⚠️  Warning: No pages extracted from {doc_name}")
                continue

            metas = _chunk_pdf_pages(pages, pdf_path, doc_name, doc_id)
            if metas:
                vectors = _embed_in_batches([m["text"] for m in metas])
                # Add the whole document at once. NOTE(review): assumes
                # FaissStore.add accepts n rows with n metas, as its list-valued
                # second argument suggests.
                store.add(vectors, metas)

            if loaded_existing and doc_name in processed_files:
                print(
                    f"\n⚠️  {doc_name} was modified; chunks from its previous version "
                    f"remain in the index. Delete {TRACKING_FILE} to rebuild cleanly."
                )

            chunk_count = len(metas)
            total_chunks_added += chunk_count

            processed_files[doc_name] = {
                "hash": file_hash,
                "processed_date": datetime.now().isoformat(),
                "chunks": chunk_count,
                "pages": len(pages)
            }
            ingestion_manifest.append({
                "doc_id": doc_id,
                "doc_name": doc_name,
                "pages": len(pages),
                "chunks": chunk_count
            })

        except Exception as e:
            print(f"\n❌ Error processing {doc_name}: {e}")

    print(f"\n➕ Added {total_chunks_added} chunk(s) from {len(ingestion_manifest)} file(s)")

    # Save updated index
    print(f"\n💾 Saving updated index...")
    store.save(INDEX_DIR)
    print(f"   ✅ Index saved to: {INDEX_DIR}")

    # Save tracking file
    save_processed_files(processed_files)
    print(f"   ✅ Tracking file updated: {TRACKING_FILE}")

    # Save ingestion manifest (describes only the files handled in this run)
    manifest_path = os.path.join(os.path.dirname(INDEX_DIR), "ingestion_manifest.json")
    try:
        with open(manifest_path, "w", encoding="utf-8") as f:
            json.dump(ingestion_manifest, f, indent=2, ensure_ascii=False)
        print(f"   ✅ Manifest saved: {manifest_path}")
    except Exception as e:
        print(f"   ⚠️  Could not save manifest: {e}")

    return INDEX_DIR


# ============================================================================
# MAIN ENTRY POINT
# ============================================================================

def main(folder: str = "data") -> int:
    """
    Main entry point for the ingestion script.

    Validates the input folder, runs incremental ingestion and prints a summary
    of the resulting index and tracking file.

    Args:
        folder: Folder scanned for ``*.pdf`` files (default: "data").

    Returns:
        Exit status: 0 on success; 1 if the folder is missing, contains no
        PDFs, or ingestion raised an exception.
    """
    print("\n" + "=" * 60)
    print("🚀 RAG INCREMENTAL INGESTION SYSTEM")
    print("=" * 60)

    # isdir rather than exists: a plain file with this name is unusable too.
    if not os.path.isdir(folder):
        print(f"\n❌ Error: Folder '{folder}' not found!")
        print(f"   Please create it and add PDF files.")
        print("=" * 60 + "\n")
        return 1

    if not glob.glob(os.path.join(folder, "*.pdf")):
        print(f"\n❌ No PDF files found in '{folder}'!")
        print(f"   Please add PDF files to the folder.")
        print("=" * 60 + "\n")
        return 1

    try:
        index_path = ingest_folder_incremental(folder)

        print("\n" + "=" * 60)
        print("✅ INGESTION COMPLETE!")
        print("=" * 60)

        print(f"\n📊 Final Index Statistics:")
        if os.path.exists(index_path):
            try:
                store = FaissStore.load(index_path)
                print(f"   - Total chunks in index: {len(store.meta):,}")
            except Exception as e:
                # Stats are informational only: report, don't fail the run.
                print(f"   - Total chunks in index: unavailable ({e})")

        processed = load_processed_files()
        print(f"   - Total tracked files: {len(processed)}")

        print(f"\n📁 Output Locations:")
        print(f"   - Index: {index_path}")
        print(f"   - Tracking: {TRACKING_FILE}")

        print(f"\n💡 Next Steps:")
        print(f"   1. Test Q&A: python -m rag.qa")
        print(f"   2. Start API: uvicorn api:app --reload")
        print(f"   3. Add more PDFs to '{folder}/' and run again")

        print("\n" + "=" * 60 + "\n")
        return 0

    except Exception as e:
        print("\n" + "=" * 60)
        print("❌ INGESTION FAILED!")
        print("=" * 60)
        print(f"\nError: {e}\n")

        import traceback
        print("Full traceback:")
        traceback.print_exc()
        print("\n" + "=" * 60 + "\n")
        return 1


if __name__ == "__main__":
    # Propagate main()'s status so shells/CI can detect a failed ingestion.
    raise SystemExit(main())
