# app.py
import html
import json
import os
from datetime import datetime
from pathlib import Path

import streamlit as st

from rag import config
from rag.qa import answer_question_with_llm, embed_query
from rag.render import render_docx
from rag.report import build_report_from_query
from rag.vectorstore import FaissStore

# ============================================
# PAGE CONFIG
# ============================================
# Browser-tab title/icon and a full-width layout for the whole app.
_PAGE_CONFIG = dict(
    page_title="RAG Annual Reports System",
    page_icon="📊",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)

# ============================================
# CUSTOM CSS
# ============================================
# Style sheet injected once per run; class names are referenced by the
# HTML snippets rendered further down (e.g. main-title, success-box).
_CUSTOM_CSS = """
<style>
    .main-title {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1e3a8a;
        text-align: center;
        padding: 1rem 0;
    }
    .stButton>button {
        width: 100%;
        background-color: #1e3a8a;
        color: white;
        font-weight: bold;
        padding: 0.75rem;
        border-radius: 0.5rem;
    }
    .stButton>button:hover {
        background-color: #2563eb;
    }
    .info-box {
        background-color: #f0f9ff;
        border-left: 4px solid #3b82f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
    .success-box {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 1.5rem;
        border-radius: 0.75rem;
        margin: 1rem 0;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        font-size: 1.05rem;
        line-height: 1.6;
    }
    .warning-box {
        background-color: #fffbeb;
        border-left: 4px solid #f59e0b;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# ============================================
# INITIALIZE SESSION STATE
# ============================================
# `store` is only assigned after a successful load. While the index is
# missing or broken, every rerun therefore retries loading, which also picks
# up a freshly ingested index without pressing "Reload Index".
if 'store' not in st.session_state:
    if os.path.exists(config.INDEX_DIR):
        try:
            st.session_state.store = FaissStore.load(config.INDEX_DIR)
            st.session_state.index_loaded = True
            # Drop any error recorded by an earlier failed attempt so it is
            # not reported after the index has loaded.
            st.session_state.pop('error', None)
        except Exception as e:
            st.session_state.index_loaded = False
            st.session_state.error = str(e)
    else:
        st.session_state.index_loaded = False
        # A previous load error no longer applies once the directory is gone.
        st.session_state.pop('error', None)

# Per-session histories shown in the History tab.
if 'qa_history' not in st.session_state:
    st.session_state.qa_history = []

if 'report_history' not in st.session_state:
    st.session_state.report_history = []

# ============================================
# HELPER FUNCTIONS
# ============================================

def get_index_stats(store=None):
    """Summarize the indexed chunks for display in the sidebar.

    Args:
        store: Object exposing a ``meta`` list of per-chunk metadata dicts.
            When omitted, ``st.session_state.store`` is used.

    Returns:
        ``None`` when no store is passed and the index is not loaded;
        otherwise a dict with ``total_chunks``, ``total_documents``,
        ``total_pages`` and ``documents`` (document ids, sorted so the
        listing is stable across reruns).
    """
    if store is None:
        if not st.session_state.index_loaded:
            return None
        store = st.session_state.store

    doc_ids = set()
    pages = set()

    for meta in store.meta:
        # Normalize absent *and* None ids to one placeholder so both counters
        # agree, and so callers never receive None as a document id.
        doc_id = meta.get('doc_id') or 'Unknown'
        doc_ids.add(doc_id)
        pages.add((doc_id, meta.get('page_number')))

    return {
        "total_chunks": len(store.meta),
        "total_documents": len(doc_ids),
        "total_pages": len(pages),
        "documents": sorted(doc_ids, key=str),
    }

# ============================================
# HEADER
# ============================================
# Page title styled by the .main-title CSS class, followed by a divider.
_TITLE_HTML = '<p class="main-title">📊 RAG Annual Reports System</p>'
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.markdown("---")

# ============================================
# SIDEBAR - SYSTEM STATUS
# ============================================
with st.sidebar:
    st.header("🎯 System Status")
    
    if st.session_state.index_loaded:
        st.success("✅ Index Loaded")
        
        stats = get_index_stats()
        if stats:
            st.metric("Total Chunks", f"{stats['total_chunks']:,}")
            st.metric("Documents", stats['total_documents'])
            st.metric("Pages Indexed", f"{stats['total_pages']:,}")
            
            with st.expander("📚 Indexed Documents"):
                for doc_id in stats['documents']:
                    # str() guards against non-string ids (e.g. None) in metadata.
                    doc_name = Path(str(doc_id)).name
                    st.text(f"• {doc_name}")
    else:
        st.error("❌ Index Not Loaded")
        # Surface the exception recorded by the startup load attempt; it is
        # stored in session state but would otherwise never be displayed.
        load_error = st.session_state.get('error')
        if load_error:
            st.caption(f"Load error: {load_error}")
        st.info("Run `python -m rag.ingest` to create the index.")
    
    st.markdown("---")
    
    st.header("⚙️ Configuration")
    st.text(f"Embedding: {config.EMBED_MODEL}")
    st.text(f"LLM: {config.LLM_MODEL}")
    st.text(f"Chunk Size: {config.CHUNK_SIZE}")
    st.text(f"Overlap: {config.CHUNK_OVERLAP}")
    
    st.markdown("---")
    
    # A message emitted right before st.rerun() is wiped by the rerun, so the
    # reload confirmation is queued in session state and shown once here.
    reload_notice = st.session_state.pop('reload_notice', None)
    if reload_notice:
        st.success(reload_notice)
    
    if st.button("🔄 Reload Index"):
        if os.path.exists(config.INDEX_DIR):
            reloaded = False
            try:
                st.session_state.store = FaissStore.load(config.INDEX_DIR)
                st.session_state.index_loaded = True
                st.session_state.pop('error', None)
                reloaded = True
            except Exception as e:
                st.error(f"❌ Error: {e}")
            # Rerun outside the try so the broad `except Exception` above can
            # never intercept Streamlit's rerun control-flow exception.
            if reloaded:
                st.session_state.reload_notice = "✅ Index reloaded!"
                st.rerun()
        else:
            st.error("❌ Index not found!")

# ============================================
# MAIN NAVIGATION
# ============================================
# One tab per feature: question answering, report generation, history.
_TAB_LABELS = ["🔍 Q&A", "📝 Generate Report", "📜 History"]
tab1, tab2, tab3 = st.tabs(_TAB_LABELS)

# ============================================
# TAB 1: Q&A
# ============================================
with tab1:
    st.header("🔍 Ask Questions About Your Documents")
    st.markdown("Enter your question below and get AI-powered answers with source citations.")
    
    # Question input
    question = st.text_input(
        "Your Question:",
        placeholder="e.g., Who are the directors of Genomic Valley?",
        key="qa_question_input"
    )
    
    col1, col2 = st.columns([3, 1])
    
    with col1:
        ask_button = st.button("🚀 Get Answer", key="ask_btn", type="primary")
    
    with col2:
        k_value = st.number_input("Top K", min_value=1, max_value=50, value=8, key="qa_k")
    
    # NOTE(review): everything below is rendered only on the run where the
    # button was clicked; clicking the download button triggers a new run and
    # the displayed result disappears. Persisting it in st.session_state would
    # keep it visible.
    if ask_button:
        # Whitespace-only input is treated the same as an empty question.
        if not question.strip():
            st.warning("⚠️ Please enter a question.")
        elif not st.session_state.index_loaded:
            st.error("❌ Index not loaded. Please run ingestion first.")
        else:
            with st.spinner("🔎 Searching through documents..."):
                try:
                    # Retrieve chunks
                    qvec = embed_query(question)
                    results = st.session_state.store.search(qvec, k=k_value)
                    
                    if not results:
                        st.warning("❌ No relevant information found.")
                    else:
                        # Display retrieved evidence
                        st.markdown("### 📄 Retrieved Evidence")
                        
                        with st.expander(f"View {len(results)} Retrieved Chunks", expanded=False):
                            for idx, r in enumerate(results, 1):
                                meta = r["metadata"]
                                score = r.get("score", 0.0)
                                
                                st.markdown(f"**Chunk {idx}** | {meta.get('doc_name', 'Unknown')} | Page {meta.get('page_number', '?')} | Score: {score:.3f}")
                                st.text_area(
                                    f"Content {idx}",
                                    meta.get('text', ''),
                                    height=100,
                                    key=f"chunk_display_{idx}",
                                    label_visibility="collapsed"
                                )
                                st.markdown("---")
                        
                        # Generate answer
                        with st.spinner("🤖 Generating answer..."):
                            answer = answer_question_with_llm(question, results)
                            
                            st.markdown("### ✅ Answer")
                            # The answer is model output derived from document
                            # text, so it is HTML-escaped before being placed in
                            # raw HTML. Newlines become <br> because Markdown
                            # inside a raw <div> is not reliably rendered, and a
                            # blank line would otherwise end the styled box.
                            answer_html = html.escape(str(answer)).replace("\n", "<br>")
                            st.markdown(f'<div class="success-box">{answer_html}</div>', unsafe_allow_html=True)
                            
                            # One timestamp for history, JSON and filename.
                            now = datetime.now()
                            
                            # Save to history
                            st.session_state.qa_history.append({
                                "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
                                "question": question,
                                "answer": answer,
                                "num_sources": len(results)
                            })
                            
                            # Download option
                            answer_data = {
                                "question": question,
                                "timestamp": now.isoformat(),
                                "answer": answer,
                                "sources": [
                                    {
                                        "doc": r["metadata"].get("doc_name"),
                                        "page": r["metadata"].get("page_number"),
                                        "chunk_id": r["metadata"].get("chunk_id"),
                                        # float() converts NumPy scalar scores (if
                                        # the store returns them) into plain floats
                                        # that json.dumps can serialize.
                                        "score": float(r.get("score", 0.0)),
                                    }
                                    for r in results
                                ]
                            }
                            
                            st.download_button(
                                "📥 Download Answer (JSON)",
                                data=json.dumps(answer_data, indent=2, ensure_ascii=False),
                                file_name=f"answer_{now.strftime('%Y%m%d_%H%M%S')}.json",
                                mime="application/json"
                            )
                
                except Exception as e:
                    st.error(f"❌ Error: {str(e)}")

# ============================================
# TAB 2: GENERATE REPORT
# ============================================
with tab2:
    st.header("📝 Generate Structured Reports")
    st.markdown("Enter a comprehensive query to generate a detailed multi-section report.")
    
    # Report query input
    report_query = st.text_area(
        "Report Query / Prompt:",
        placeholder="e.g., Generate a comprehensive analysis of all companies covering management, shareholding, financial performance, and risks.",
        height=120,
        key="report_query_input"
    )
    
    col1, col2 = st.columns([2, 1])
    
    with col1:
        report_filename = st.text_input(
            "Report Filename:",
            value="Generated_Report.docx",
            key="report_filename_input"
        )
    
    with col2:
        k_report = st.number_input(
            "Chunks to Retrieve:",
            min_value=10,
            max_value=200,
            value=100,
            step=10,
            key="report_k"
        )
    
    generate_button = st.button("📊 Generate Report", key="generate_report_btn", type="primary")
    
    if generate_button:
        # Whitespace-only input is treated the same as an empty query.
        if not report_query.strip():
            st.warning("⚠️ Please enter a report query.")
        elif not st.session_state.index_loaded:
            st.error("❌ Index not loaded. Please run ingestion first.")
        else:
            # The filename is free text from the UI. Keep only its final path
            # component so input like "../../x.docx" cannot write outside
            # outputs/reports; fall back to the default for an empty name
            # (which would otherwise resolve to the directory itself) and
            # make sure the file carries a .docx extension.
            safe_filename = Path(report_filename.strip()).name or "Generated_Report.docx"
            if not safe_filename.lower().endswith(".docx"):
                safe_filename += ".docx"
            
            progress_bar = st.progress(0)
            status_text = st.empty()
            
            try:
                # build_report_from_query receives the store and k, so it
                # presumably does retrieval and section generation in one call;
                # both status messages are therefore shown before it runs.
                status_text.info("🔍 Retrieving relevant chunks...")
                progress_bar.progress(20)
                
                status_text.info("🤖 Generating report sections with LLM...")
                progress_bar.progress(40)
                
                report_obj = build_report_from_query(
                    st.session_state.store,
                    report_query,
                    k=k_report
                )
                
                progress_bar.progress(70)
                
                # Render DOCX
                status_text.info("📄 Rendering DOCX document...")
                
                output_dir = Path("outputs/reports")
                output_dir.mkdir(parents=True, exist_ok=True)
                
                report_path = output_dir / safe_filename
                final_path = render_docx(report_obj, str(report_path))
                
                progress_bar.progress(100)
                status_text.success("✅ Report generated successfully!")
                
                st.success(f"✅ Report saved at: `{final_path}`")
                
                docs = report_obj.get('docs', [])
                
                # Save to history
                st.session_state.report_history.append({
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "query": report_query,
                    "filename": final_path,
                    "docs": docs
                })
                
                # Display report preview
                st.markdown("### 📋 Report Preview")
                
                st.markdown(f"**Query:** {report_obj.get('query', report_query)}")
                st.markdown(f"**Documents Analyzed:** {len(docs)}")
                
                for section_name, content in report_obj.get('sections', {}).items():
                    with st.expander(f"**{section_name}**"):
                        st.markdown(content[:500] + "..." if len(content) > 500 else content)
                
                # Offer the file that was actually written: render_docx
                # returns the final path, which may differ from the request.
                if os.path.exists(final_path):
                    with open(final_path, "rb") as f:
                        st.download_button(
                            "📥 Download Report (DOCX)",
                            data=f.read(),
                            file_name=Path(final_path).name,
                            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                        )
            
            except Exception as e:
                st.error(f"❌ Error generating report: {str(e)}")
                progress_bar.empty()
                status_text.empty()

# ============================================
# TAB 3: HISTORY
# ============================================
def _truncate(text, limit):
    """Return *text* cut to *limit* characters, adding "..." only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _clear_history(state_key):
    """Button callback: reset the session-state history list named *state_key*."""
    st.session_state[state_key] = []


with tab3:
    st.header("📜 Query & Report History")
    
    col1, col2 = st.columns(2)
    
    with col1:
        st.subheader("🔍 Q&A History")
        qa_history = st.session_state.qa_history
        if qa_history:
            # Newest first, while keeping each entry's original 1-based number.
            for number, item in reversed(list(enumerate(qa_history, 1))):
                with st.expander(f"Q{number}: {_truncate(item['question'], 50)}"):
                    st.markdown(f"**Time:** {item['timestamp']}")
                    st.markdown(f"**Question:** {item['question']}")
                    st.markdown(f"**Sources:** {item['num_sources']} chunks")
                    st.markdown("**Answer:**")
                    st.info(_truncate(item['answer'], 300))
        else:
            st.info("No Q&A history yet.")
    
    with col2:
        st.subheader("📝 Report History")
        report_history = st.session_state.report_history
        if report_history:
            for number, item in reversed(list(enumerate(report_history, 1))):
                with st.expander(f"Report {number}: {_truncate(item['query'], 50)}"):
                    st.markdown(f"**Time:** {item['timestamp']}")
                    st.markdown(f"**Query:** {item['query']}")
                    st.markdown(f"**File:** `{Path(item['filename']).name}`")
                    st.markdown(f"**Documents:** {len(item['docs'])}")
        else:
            st.info("No report history yet.")
    
    # Clear history buttons. on_click callbacks run before the next script
    # run, so the lists above are already empty when re-rendered and the
    # confirmation stays visible (an explicit st.rerun() would discard it).
    st.markdown("---")
    col1, col2 = st.columns(2)
    
    with col1:
        if st.button("🗑️ Clear Q&A History", on_click=_clear_history, args=("qa_history",)):
            st.success("Q&A history cleared!")
    
    with col2:
        if st.button("🗑️ Clear Report History", on_click=_clear_history, args=("report_history",)):
            st.success("Report history cleared!")

# ============================================
# FOOTER
# ============================================
# Centered credits line with the current year, below a divider.
_footer_html = (
    "<div style='text-align: center; color: #64748b; padding: 1rem;'>"
    + "Built with Streamlit • Powered by OpenAI & FAISS<br>"
    + "Version 1.0 | "
    + str(datetime.now().year)
    + "</div>"
)
st.markdown("---")
st.markdown(_footer_html, unsafe_allow_html=True)
