"""
Unified Document Text Extraction Module

This module provides comprehensive text extraction capabilities for tender RFP documents.
It handles PDF, Excel, Word documents and scanned images, returning both individual 
document texts and merged content.
"""

import os
import io
import json
import base64
import time
import logging
from datetime import datetime
from typing import Dict, Tuple, Optional

import PyPDF2
import pandas as pd
from docx2python import docx2python
import nltk
from nltk.tokenize import word_tokenize
from google import genai
from google.genai import types
import anthropic

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# Download required NLTK tokenizer data (used by count_words)
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception:
    # count_words falls back to whitespace splitting if this data is missing
    pass

llm_model = "claude" # "gemini"

class DocumentExtractor:
    """
    Unified document text extraction class that handles multiple file types
    and provides both individual and merged text outputs.
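
    Example (hypothetical API key and directory):
        extractor = DocumentExtractor("my-api-key")
        docs, ok, msg = extractor.extract_documents_text("./docs")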
    """

    def __init__(self, llm_api_key: str, process_logger=None):
        """
        Initialize the document extractor.

        Args:
            llm_api_key (str): API key for the configured LLM backend
                (Gemini or Claude; used for scanned pages)
            process_logger: Optional process logger for detailed logging
        """
        self.llm_api_key = llm_api_key
        self.process_logger = process_logger
        if llm_model == "gemini":
            self.gemini_client = genai.Client(api_key=llm_api_key) if llm_api_key else None
        elif llm_model == "claude":
            self.claude_client = anthropic.Anthropic(api_key=llm_api_key) if llm_api_key else None
        self._extraction_cache = {}  # Cache to avoid re-extraction

    def _log_info(self, msg):
        """Log info message to both module logger and process logger"""
        logger.info(msg)
        if self.process_logger:
            self.process_logger.info(msg)

    def _log_warning(self, msg):
        """Log warning message to both module logger and process logger"""
        logger.warning(msg)
        if self.process_logger:
            self.process_logger.warning(msg)

    def _log_error(self, msg):
        """Log error message to both module logger and process logger"""
        logger.error(msg)
        if self.process_logger:
            self.process_logger.error(msg)
    
    def count_words(self, text: str) -> int:
        """Count words in text, falling back to whitespace splitting if NLTK tokenization fails."""
        try:
            return len(word_tokenize(text))
        except Exception:
            return len(text.split())
    
    def is_scanned_page(self, page) -> bool:
        """
        Determine whether a PDF page is likely a scanned image using PyPDF2.

        Heuristic: a page yielding fewer than 20 characters of selectable text
        is treated as scanned.

        Args:
            page: PyPDF2 page object

        Returns:
            bool: True if the page appears to be scanned, False otherwise
        """
        text = page.extract_text()
        return not text or len(text.strip()) < 20
    

    def extract_text_from_pdf(self, pdf_path: str, batch_size: int = 5) -> str:
        """
        Extract text content from PDF files, using Gemini for scanned pages in batches
        and PyPDF2 for pages with selectable text
        
        Args:
            pdf_path (str): Path to the PDF file
            batch_size (int): Maximum number of scanned pages to send to Gemini in one request
            
        Returns:
            str: Extracted text content
        """
        extracted_text = {}  # Dictionary to store text by page number
        scanned_pages = []
        
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                total_pages = len(pdf_reader.pages)
                
                self._log_info(f"Processing PDF {os.path.basename(pdf_path)} with {total_pages} pages")

                # First pass: process selectable text pages and identify scanned pages
                for page_num in range(total_pages):
                    page = pdf_reader.pages[page_num]

                    if self.is_scanned_page(page):
                        scanned_pages.append(page_num)
                        self._log_info(f"Page {page_num + 1} appears to be a scanned page.")
                    else:
                        text = page.extract_text()
                        extracted_text[page_num] = f"--- Doc Page Number: {page_num + 1} ---\n{text}\n\n"

                self._log_info(f"Found {len(scanned_pages)} scanned pages in the document")
                
                # Second pass: process scanned pages in batches with the configured LLM
                llm_client = getattr(self, "gemini_client", None) or getattr(self, "claude_client", None)
                if scanned_pages and llm_client:
                    for i in range(0, len(scanned_pages), batch_size):
                        batch = scanned_pages[i:i+batch_size]
                        
                        if not batch:
                            continue

                        self._log_info(f"Processing batch of {len(batch)} scanned pages: {[p+1 for p in batch]}")
                        
                        # Create a writer for this batch
                        writer = PyPDF2.PdfWriter()
                        
                        # Add each page in the batch
                        for page_num in batch:
                            writer.add_page(pdf_reader.pages[page_num])
                        
                        # Write batch to a memory buffer
                        output_buffer = io.BytesIO()
                        writer.write(output_buffer)
                        output_buffer.seek(0)
                        

                        # Prepare the extraction prompt for the LLM
                        prompt = f"""
                        This message contains {len(batch)} scanned pages from a PDF document.
                        
                        For EACH PAGE, extract ALL text content visible in it. Separate the pages in your response with a clear delimiter line, like:
                        
                        "=== NEW PAGE ==="
                        [extracted text for page 1]
                        
                        "=== NEW PAGE ==="
                        [extracted text for page 2]
                        
                        And so on for each page. Format the text naturally, preserving paragraphs, bullet points, and other structural elements.
                        Please just give out the extracted text as output without any additional commentary.

                        NOTE that the text in the scanned pages can be in any language.
                        """
                        if llm_model == 'gemini':
                            # Call the Gemini API with only the batch of scanned pages
                            batch_pdf_bytes = output_buffer.getvalue()
                            
                            response = self.gemini_client.models.generate_content(
                                model="gemini-2.5-flash-preview-04-17", 
                                contents=[
                                    types.Part.from_bytes(
                                        data=batch_pdf_bytes,
                                        mime_type='application/pdf',
                                    ),
                                    prompt
                                ]
                            )
                            
                            llm_text_out = response.text

                        elif llm_model == 'claude':
                            # Encode the batch PDF as base64 for Claude
                            pdf_base64 = base64.b64encode(output_buffer.getvalue()).decode('utf-8')

                            response = self.claude_client.messages.create(
                                model="claude-3-5-haiku-latest",
                                max_tokens=4000,
                                temperature=0.2,
                                system="You are an expert at extracting text of any language from scanned documents. Extract ALL text visible on EACH page, preserving the original formatting as much as possible.",
                                messages=[
                                    {
                                        "role": "user", 
                                        "content": [
                                            {
                                                "type": "text",
                                                "text": prompt
                                            },
                                            {
                                                "type": "document",
                                                "source": {
                                                    "type": "base64",
                                                    "media_type": "application/pdf",
                                                    "data": pdf_base64
                                                }
                                            }
                                        ]
                                    }
                                ]
                            )

                            llm_text_out = response.content[0].text
                        
                        # Parse the LLM's response to extract text for each page
                        scanned_pages_text = llm_text_out.split("=== NEW PAGE ===")
                        
                        # Remove the first element if it's empty or contains intro text
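                        # (Illustrative, with a hypothetical response: "intro\n=== NEW PAGE ===\n<p1>\n=== NEW PAGE ===\n<p2>"
                        # splits into ["intro\n", "\n<p1>\n", "\n<p2>"], one element more
                        # than the batch, so the stray first element is dropped.)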
                        if scanned_pages_text and (len(scanned_pages_text) > len(batch)):
                            scanned_pages_text = scanned_pages_text[1:]
                        
                        # Map extracted text to corresponding page numbers
                        for j, page_num in enumerate(batch):
                            if j < len(scanned_pages_text):
                                page_content = scanned_pages_text[j].strip()
                                if page_content:
                                    extracted_text[page_num] = f"--- Doc Page Number: {page_num + 1} ---\n{page_content}\n\n"
                                else:
                                    extracted_text[page_num] = f"--- Doc Page Number: {page_num + 1} ---\n[No text content extracted]\n\n"
                            else:
                                extracted_text[page_num] = f"--- Doc Page Number: {page_num + 1} ---\n[Could not extract page content]\n\n"
                        
                        # Add small delay between API calls to avoid rate limiting
                        time.sleep(1)
                elif scanned_pages:
                    self._log_warning("Scanned pages found but no LLM API key provided. Skipping scanned pages.")
        
        except Exception as e:
            self._log_info(f"Error extracting text from PDF {pdf_path}: {str(e)}")
            return f"Error processing PDF: {str(e)}"

        # Combine all pages in correct order
        combined_text = ""
        for page_num in range(total_pages):
            if page_num in extracted_text:
                combined_text += extracted_text[page_num]

        self._log_info(f"Processing complete. Processed {total_pages} pages, including {len(scanned_pages)} scanned pages.")

        return combined_text


    def extract_table_from_excel(self, excel_filepath: str, separator: str = "  ||  ") -> str:
        """
        Converts an Excel file to a formatted text file, maintaining table structure.

        Args:
            excel_filepath (str): Path to the Excel file.
            separator (str, optional): Separator character between columns. Defaults to "  ||  ".
            
        Returns:
            str: Formatted text representation of the Excel file
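
        Example (hypothetical two-column sheet):
            Item      ||  Qty
            Widget A  ||  3
            Widget B  ||  12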
        """
        try:
            df = pd.read_excel(excel_filepath, header=0)
            df = df.fillna('')
            df = df.astype(str)

            # Column width = longest cell in the column, capped at 50 characters;
            # the header length is the fallback for empty columns
            col_widths = [
                min(50, max((len(str(x)) for x in df[col].values), default=len(str(col))))
                for col in df.columns
            ]
            
            # Build the header row with separators
            header_line = separator.join(
                f"{col:<{width}}" for col, width in zip(df.columns, col_widths)
            )
            # Build the data rows with separators
            data_lines = []
            for _, row in df.iterrows():
                data_line = separator.join(
                    f"{str(cell):<{width}}" for cell, width in zip(row, col_widths)
                )
                data_lines.append(data_line)

            return '\n'.join([header_line] + data_lines)
        except Exception as e:
            self._log_error(f"Error extracting from Excel {excel_filepath}: {str(e)}")
            return f"Error processing Excel: {str(e)}"

    def extract_text_from_word(self, word_path: str, temp_dir: str) -> str:
        """
        Extract text from a Word document using docx2python.

        Note: docx2python only supports .docx files; legacy .doc files will
        return an error string.

        Args:
            word_path (str): Path to the Word document
            temp_dir (str): Directory for temporary files (retained for
                backward compatibility; no longer used)

        Returns:
            str: Extracted text content
        """
        try:
            return self.get_structured_docx_content(word_path)
        except Exception as e:
            self._log_error(f"Error extracting from Word {word_path}: {str(e)}")
            return f"Error processing Word document: {str(e)}"
    
    def extract_documents_text(self, input_dir: str, cache_key: Optional[str] = None) -> Tuple[Dict[str, str], bool, str]:
        """
        Extract text from all supported documents in a directory.
        
        Args:
            input_dir (str): Directory containing documents to process
            cache_key (str, optional): Key for caching results
            
        Returns:
            Tuple[Dict[str, str], bool, str]: 
                - Dictionary mapping file paths to extracted text
                - Success status
                - Status message
        """
        # Check cache first
        if cache_key and cache_key in self._extraction_cache:
            self._log_info(f"Using cached extraction results for {cache_key}")
            return self._extraction_cache[cache_key]
        
        if not os.path.exists(input_dir):
            return {}, False, f"Directory {input_dir} does not exist"
        
        # Categorize files by type
        pdf_files = []
        excel_files = []
        word_files = []

        for filename in os.listdir(input_dir):
            file_path = os.path.join(input_dir, filename)
            lower_name = filename.lower()
            if lower_name.endswith(('.xls', '.xlsx')):
                excel_files.append(file_path)
            elif lower_name.endswith('.pdf'):
                pdf_files.append(file_path)
            elif lower_name.endswith(('.doc', '.docx')):
                word_files.append(file_path)
        
        if not pdf_files and not word_files and not excel_files:
            return {}, False, f"No supported document files found in {input_dir}"

        self._log_info(f"Found {len(pdf_files)} PDF files, {len(excel_files)} Excel files, {len(word_files)} Word files to process")

        documents_text = {}
        successful_files = []
        failed_files = []
        start_time = datetime.now()

        # Process PDFs
        for file_path in pdf_files:
            try:
                text = self.extract_text_from_pdf(file_path)
                if text and not text.startswith("Error"):
                    documents_text[file_path] = text
                    successful_files.append(file_path)
                    self._log_info(f"Successfully extracted text from {os.path.basename(file_path)}")
                else:
                    failed_files.append(file_path)
                    self._log_warning(f"Failed to extract text from {os.path.basename(file_path)}")
            except Exception as e:
                self._log_error(f"Error processing PDF {file_path}: {e}")
                failed_files.append(file_path)

        # Process Excel files
        for file_path in excel_files:
            try:
                text = self.extract_table_from_excel(file_path)
                if text and not text.startswith("Error"):
                    documents_text[file_path] = text
                    successful_files.append(file_path)
                    self._log_info(f"Successfully extracted text from {os.path.basename(file_path)}")
                else:
                    failed_files.append(file_path)
                    self._log_warning(f"Failed to extract text from {os.path.basename(file_path)}")
            except Exception as e:
                self._log_error(f"Error processing Excel {file_path}: {e}")
                failed_files.append(file_path)

        # Process Word files
        for file_path in word_files:
            try:
                text = self.extract_text_from_word(file_path, input_dir)
                if text and not text.startswith("Error"):
                    documents_text[file_path] = text
                    successful_files.append(file_path)
                    self._log_info(f"Successfully extracted text from {os.path.basename(file_path)}")
                else:
                    failed_files.append(file_path)
                    self._log_warning(f"Failed to extract text from {os.path.basename(file_path)}")
            except Exception as e:
                self._log_error(f"Error processing Word {file_path}: {e}")
                failed_files.append(file_path)

        processing_time = datetime.now() - start_time
        
        if not documents_text:
            result = {}, False, "No text was successfully extracted from any documents"
        else:
            total_chars = sum(len(text) for text in documents_text.values())
            result = documents_text, True, f"Successfully extracted text from {len(successful_files)} documents ({total_chars} total characters) in {processing_time}"
        
        # Cache the result
        if cache_key:
            self._extraction_cache[cache_key] = result
        
        return result
    
    def create_merged_text_file(self, documents_text: Dict[str, str], output_path: str) -> bool:
        """
        Create a merged text file from individual document texts.
        
        Args:
            documents_text (Dict[str, str]): Dictionary mapping file paths to text content
            output_path (str): Path where to save the merged text file
            
        Returns:
            bool: Success status
        """
        try:
            merged_text = "\n\n".join(documents_text.values())
            
            # Ensure the output directory exists (skip when output_path has no directory part)
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            
            with open(output_path, 'w', encoding="utf-8") as file:
                file.write(merged_text)

            self._log_info(f"Merged text ({len(merged_text)} chars) successfully written to {output_path}")
            return True
        except Exception as e:
            self._log_error(f"Error creating merged text file: {e}")
            return False
    
    def count_total_words(self, documents_text: Dict[str, str]) -> int:
        """
        Count the total number of words in all document texts.
        
        Args:
            documents_text (Dict[str, str]): Dictionary mapping file paths to their text content
            
        Returns:
            int: Total word count across all documents
        """
        total_word_count = 0
        
        for file_path, text_content in documents_text.items():
            if not text_content:
                continue
                
            word_count = self.count_words(text_content)
            total_word_count += word_count

            self._log_info(f"Document: {os.path.basename(file_path)} - {word_count:,} words")

        self._log_info(f"\nTotal words across all documents: {total_word_count:,}")
        return total_word_count

    def get_structured_docx_content(self, file_path):
        """Returns the content of a .docx file as a structured list."""
        try:
            with docx2python(file_path) as docx_content:
                return str(docx_content.body)
        except Exception as e:
            return f"Error reading file: {e}"


# Factory function for easy initialization
def create_document_extractor(llm_api_key: str) -> DocumentExtractor:
    """
    Factory function to create a DocumentExtractor instance.
    
    Args:
        llm_api_key (str): API key for the configured LLM backend (Gemini or Claude)
        
    Returns:
        DocumentExtractor: Initialized extractor instance
    """
    return DocumentExtractor(llm_api_key)


# Backward compatibility functions for existing code
def extract_RFP_text_compatible(input_dir: str, output_dir: str, llm_api_key: str) -> Tuple[bool, str]:
    """
    Backward compatible version of extract_RFP_text for bid_queries.py
    
    Args:
        input_dir (str): Directory containing input documents
        output_dir (str): Directory for output files
        llm_api_key (str): API key for LLM
        
    Returns:
        Tuple[bool, str]: Success status and message
    """

    merged_file_path = os.path.join(output_dir, "merged.txt")

    if os.path.exists(merged_file_path):
        return True, f"The file '{merged_file_path}' already exists."

    extractor = DocumentExtractor(llm_api_key)

    # Extract the per-document text
    documents_text, success, message = extractor.extract_documents_text(input_dir, cache_key=input_dir)

    if not success:
        return False, message

    # Persist the per-document text alongside the merged file
    os.makedirs(output_dir, exist_ok=True)
    documents_text_path = os.path.join(output_dir, "doc_text.json")
    with open(documents_text_path, 'w', encoding='utf-8') as f:
        json.dump(documents_text, f, ensure_ascii=False, indent=4)

    # Create the merged text file
    merge_success = extractor.create_merged_text_file(documents_text, merged_file_path)

    if merge_success:
        return True, "Successfully created merged text file"
    else:
        return False, "Failed to create merged text file"


def extract_documents_text_compatible(input_dir: str, output_dir: str, llm_api_key: str, process_logger=None) -> Optional[Dict[str, str]]:
    """
    Backward compatible function for tender_automation.py

    Args:
        input_dir (str): Directory containing documents
        output_dir (str): Directory for output files (doc_text.json, merged.txt)
        llm_api_key (str): API key for LLM
        process_logger: Optional process logger for detailed logging

    Returns:
        Optional[Dict[str, str]]: Dictionary mapping file paths to extracted text,
            or None if a cached doc_text.json exists but cannot be read
    """
    # Helper function to log with both loggers
    def log_info(msg):
        logger.info(msg)
        if process_logger:
            process_logger.info(msg)

    def log_error(msg):
        logger.error(msg)
        if process_logger:
            process_logger.error(msg)

    def log_warning(msg):
        logger.warning(msg)
        if process_logger:
            process_logger.warning(msg)

    documents_text_path = os.path.join(output_dir, "doc_text.json")

    if os.path.exists(documents_text_path):
        log_info(f"The file '{documents_text_path}' already exists.")
        try:
            with open(documents_text_path, 'r', encoding='utf-8') as file:
                documents_text = json.load(file)
            log_info(f"Loaded existing document text from cache: {len(documents_text)} documents")
            return documents_text
        except FileNotFoundError:
            log_error(f"Error: File not found: {documents_text_path}")
            return None
        except json.JSONDecodeError:
            log_error(f"Error: Invalid JSON format in: {documents_text_path}")
            return None

    else:
        log_info(f"Starting document extraction from {input_dir}")
        extractor = DocumentExtractor(llm_api_key, process_logger=process_logger)

        log_info("Extracting text from documents...")
        documents_text, success, message = extractor.extract_documents_text(input_dir, cache_key=input_dir)

        log_info(f"Saving extracted text to {documents_text_path}")
        with open(documents_text_path, 'w', encoding='utf-8') as f:
            json.dump(documents_text, f, ensure_ascii=False, indent=4)

        if success:
            log_info(f"Document extraction successful: {message}")
            log_info(f"Extracted text from {len(documents_text)} documents")

            # Create merged text file
            merged_file_path = os.path.join(output_dir, "merged.txt")
            log_info(f"Creating merged text file at {merged_file_path}")
            merge_success = extractor.create_merged_text_file(documents_text, merged_file_path)

            if merge_success:
                log_info("Successfully created merged text file")
            else:
                log_warning("Failed to create merged text file")

            return documents_text
        else:
            log_error(f"Document extraction failed: {message}")
            return {}
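

# ---------------------------------------------------------------------------
# Minimal usage sketch. The "LLM_API_KEY" environment variable and the
# "./tender_docs" directory below are assumptions for illustration; adapt
# both to your own setup.
if __name__ == "__main__":
    api_key = os.environ.get("LLM_API_KEY", "")
    extractor = create_document_extractor(api_key)

    docs_text, ok, status = extractor.extract_documents_text("./tender_docs")
    print(status)
    if ok:
        extractor.count_total_words(docs_text)
        extractor.create_merged_text_file(docs_text, "./tender_docs/merged.txt")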