"""
Enhanced Document Text Extraction Module

Extends the original document_extractor.py to handle additional file types
including CSV, images, and text files, with recursive folder processing.
"""

import base64
import csv
import io
import json
import logging
import os
import pathlib
import re
import shutil
import sys
import tempfile
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import anthropic
import google.generativeai as genai
import nltk
import pandas as pd
import PyPDF2
from docx2pdf import convert
from google.generativeai.types import content_types as types
from nltk.tokenize import word_tokenize
from PIL import Image

# Configure logging
# NOTE(review): basicConfig at import time configures the root logger of any
# application that imports this module; consider moving it to the entry point.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Download the NLTK tokenizer data used by word_tokenize. This is best effort:
# count_words falls back to whitespace splitting when the data is unavailable,
# so an offline machine or a read-only NLTK data directory must not break import.
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as exc:  # narrower than a bare except: lets KeyboardInterrupt/SystemExit through
    logger.warning("Could not download NLTK tokenizer data: %s", exc)

# Which LLM client EnhancedDocumentExtractor creates: "claude" or "gemini".
# Only the Claude client is used for OCR of scanned PDF pages and images.
llm_model = "claude"  # "gemini"

class EnhancedDocumentExtractor:
    """
    Extract plain text from PDF, Excel, Word, CSV, text and image files, and
    process whole directory trees recursively.

    Selectable PDF text is read with PyPDF2. Scanned PDF pages and images are
    sent to Claude for OCR. Word documents are converted to PDF with docx2pdf
    and then handled as PDFs.

    Error convention: the ``extract_text_from_*`` methods catch exceptions and
    return a string starting with "Error processing ..." (or
    "Error: No Claude API key ...") instead of raising; ``_is_error_result``
    recognises those strings.
    """

    # Default extension table by category. Each instance gets its own copy in
    # ``self.supported_extensions`` so it can be customised per instance.
    DEFAULT_SUPPORTED_EXTENSIONS: Dict[str, List[str]] = {
        'pdf': ['.pdf'],
        'excel': ['.xls', '.xlsx'],
        'word': ['.doc', '.docx'],
        'csv': ['.csv'],
        'text': ['.txt'],
        'image': ['.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp'],
    }

    # Claude model used for every OCR request.
    CLAUDE_MODEL = "claude-haiku-4-5"

    # Image formats the Claude API accepts as-is. Anything else (TIFF, BMP, ...)
    # is re-encoded as PNG before upload.
    _CLAUDE_IMAGE_MEDIA_TYPES = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
    }

    # Page separator requested from Claude. The optional quotes tolerate the model
    # echoing the marker wrapped in quotation marks.
    _PAGE_MARKER_RE = re.compile(r'"?=== NEW PAGE ==="?')

    # Prefixes of the failure strings this class returns in place of text.
    _ERROR_PREFIXES = ("Error processing ", "Error: No Claude API key", "Unsupported file type: ")

    _SCANNED_PAGES_SYSTEM = (
        "You are an expert at extracting text of any language from scanned documents. "
        "Extract ALL text visible on EACH page, preserving the original formatting as much as possible."
    )

    _SCANNED_PAGES_PROMPT = (
        "This message contains {page_count} scanned pages from a PDF document.\n"
        "\n"
        "For EACH PAGE, extract ALL text content visible in it. "
        "Format your response with clear page number headers like:\n"
        "\n"
        "=== NEW PAGE ===\n"
        "[extracted text for page 1]\n"
        "\n"
        "=== NEW PAGE ===\n"
        "[extracted text for page 2]\n"
        "\n"
        "And so on for each page. Format the text naturally, preserving paragraphs, "
        "bullet points, and other structural elements.\n"
        "Please just give out the extracted text as output without any additional commentary.\n"
        "\n"
        "NOTE that the text in the scanned pages can be in any language.\n"
    )

    _IMAGE_SYSTEM = (
        "You are an expert at extracting text from images. "
        "Extract ALL text visible in the image, preserving the original formatting as much as possible."
    )

    _IMAGE_PROMPT = (
        "Extract all text content from this image. Preserve formatting, structure, and layout "
        "as much as possible. If this appears to be a document, maintain the document structure."
    )

    def __init__(self, llm_api_key: str):
        """
        Initialize the enhanced document extractor.

        Args:
            llm_api_key (str): API key for Claude/Gemini (used for scanned documents
                and images). When empty/None no LLM client is created: scanned PDF
                pages are then omitted and image extraction returns an error string.
        """
        self.llm_api_key = llm_api_key
        # Always define both attributes so `if self.claude_client:` checks cannot
        # raise AttributeError when the module-level llm_model selects "gemini".
        self.gemini_client = None
        self.claude_client = None
        if llm_api_key:
            if llm_model == "gemini":
                # NOTE(review): `genai.Client` is the API of the newer `google.genai`
                # SDK; confirm it exists in `google.generativeai`. No extraction
                # method in this class uses the Gemini client.
                self.gemini_client = genai.Client(api_key=llm_api_key)
            elif llm_model == "claude":
                self.claude_client = anthropic.Anthropic(api_key=llm_api_key)

        # cache_key -> (documents_text, success, message) returned by
        # extract_documents_text_recursive; avoids re-extraction.
        self._extraction_cache: Dict[str, Tuple[Dict[str, str], bool, str]] = {}

        # Supported file extensions by category (per-instance copy).
        self.supported_extensions: Dict[str, List[str]] = {
            category: list(extensions)
            for category, extensions in self.DEFAULT_SUPPORTED_EXTENSIONS.items()
        }

    def count_words(self, text: str) -> int:
        """Count words with NLTK's tokenizer, or by whitespace if its data is missing."""
        try:
            return len(word_tokenize(text))
        except LookupError:  # raised by NLTK when the punkt tokenizer data is not installed
            return len(text.split())

    @staticmethod
    def _looks_scanned(text: Optional[str]) -> bool:
        """Return True when extracted page text is missing or under 20 non-blank-edged characters."""
        return not text or len(text.strip()) < 20

    @staticmethod
    def _page_block(page_num: int, body: str) -> str:
        """Format one page's text under a 1-based "--- Doc Page Number: N ---" header."""
        return f"--- Doc Page Number: {page_num + 1} ---\n{body}\n\n"

    def is_scanned_page(self, page) -> bool:
        """
        Determine if a PDF page is a scanned image using PyPDF2.

        Args:
            page: PyPDF2 page object

        Returns:
            bool: True if the page yields no (or almost no) selectable text
        """
        return self._looks_scanned(page.extract_text())

    @classmethod
    def _split_llm_pages(cls, llm_text: str, expected_pages: int) -> List[str]:
        """
        Split Claude's multi-page OCR reply on the page markers.

        Args:
            llm_text (str): Raw reply text
            expected_pages (int): Number of pages sent in the request

        Returns:
            List[str]: Stripped per-page texts in reply order (may be shorter or
                longer than expected_pages if the model deviated from the format)
        """
        segments = cls._PAGE_MARKER_RE.split(llm_text)
        # Text before the first marker is normally empty (the reply starts with a
        # marker) or an unrequested preamble (one segment too many); drop it then.
        if segments and (len(segments) > expected_pages or not segments[0].strip()):
            segments = segments[1:]
        return [segment.strip() for segment in segments]

    @staticmethod
    def _response_text(response) -> str:
        """Concatenate the text blocks of an Anthropic Messages API response."""
        return "".join(
            block.text for block in response.content if getattr(block, "type", None) == "text"
        )

    def _ocr_scanned_batch(self, pdf_reader, batch: List[int]) -> Dict[int, str]:
        """
        Send one batch of scanned pages to Claude as a single PDF document.

        Args:
            pdf_reader: Open PyPDF2.PdfReader the pages come from
            batch (List[int]): 0-based page indices to OCR

        Returns:
            Dict[int, str]: page index -> formatted page text (with placeholders
                for pages the reply did not cover or left empty)
        """
        writer = PyPDF2.PdfWriter()
        for page_num in batch:
            writer.add_page(pdf_reader.pages[page_num])
        output_buffer = io.BytesIO()
        writer.write(output_buffer)
        pdf_base64 = base64.b64encode(output_buffer.getvalue()).decode('utf-8')

        response = self.claude_client.messages.create(
            model=self.CLAUDE_MODEL,
            max_tokens=4000,
            temperature=0.2,
            system=self._SCANNED_PAGES_SYSTEM,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self._SCANNED_PAGES_PROMPT.format(page_count=len(batch)),
                        },
                        {
                            "type": "document",
                            "source": {
                                "type": "base64",
                                "media_type": "application/pdf",
                                "data": pdf_base64,
                            },
                        },
                    ],
                }
            ],
        )
        if getattr(response, "stop_reason", None) == "max_tokens":
            # Truncated output: later pages of the batch will get placeholders.
            logger.warning("Claude output hit max_tokens for pages %s; text may be incomplete",
                           [p + 1 for p in batch])

        pages_text = self._split_llm_pages(self._response_text(response), len(batch))

        results: Dict[int, str] = {}
        for position, page_num in enumerate(batch):
            if position >= len(pages_text):
                body = "[Could not extract page content]"
            else:
                body = pages_text[position] or "[No text content extracted]"
            results[page_num] = self._page_block(page_num, body)
        return results

    def extract_text_from_pdf(self, pdf_path: str, batch_size: int = 5) -> str:
        """
        Extract text content from a PDF file, using PyPDF2 for pages with
        selectable text and Claude (in batches) for pages that look scanned.

        Args:
            pdf_path (str): Path to the PDF file
            batch_size (int): Maximum number of scanned pages to send to Claude in
                one request; values below 1 are treated as 1

        Returns:
            str: Page texts in document order, each under a
                "--- Doc Page Number: N ---" header. Scanned pages are omitted when
                no Claude client is configured; pages whose OCR request failed get a
                "[Could not extract page content]" placeholder. If the PDF itself
                cannot be read, returns "Error processing PDF: <reason>".
        """
        batch_size = max(1, batch_size)  # range() with step 0 would raise
        extracted_text: Dict[int, str] = {}  # page index -> formatted page text
        scanned_pages: List[int] = []
        pdf_name = os.path.basename(pdf_path)

        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                total_pages = len(pdf_reader.pages)
                logger.info("Processing PDF %s with %d pages", pdf_name, total_pages)

                # First pass: keep selectable text and note which pages look scanned.
                # Each page's text is extracted once and reused for the scanned check.
                for page_num, page in enumerate(pdf_reader.pages):
                    text = page.extract_text()
                    if self._looks_scanned(text):
                        scanned_pages.append(page_num)
                        logger.info("Page %d appears to be a scanned page.", page_num + 1)
                    else:
                        extracted_text[page_num] = self._page_block(page_num, text)

                logger.info("Found %d scanned pages in the document", len(scanned_pages))

                # Second pass: OCR scanned pages with Claude, batch_size pages per request.
                # Must run inside the `with` block: PdfReader reads pages from the open file.
                if scanned_pages and not self.claude_client:
                    logger.warning("No Claude client configured; skipping %d scanned page(s) in %s",
                                   len(scanned_pages), pdf_name)
                elif scanned_pages:
                    batches = [scanned_pages[i:i + batch_size]
                               for i in range(0, len(scanned_pages), batch_size)]
                    for batch_index, batch in enumerate(batches):
                        if batch_index:
                            time.sleep(1)  # small delay between API calls to avoid rate limiting
                        logger.info("Processing batch of %d scanned pages: %s",
                                    len(batch), [p + 1 for p in batch])
                        try:
                            extracted_text.update(self._ocr_scanned_batch(pdf_reader, batch))
                        except Exception:
                            # One failed request must not discard the rest of the document.
                            logger.exception("Claude OCR failed for pages %s of %s",
                                             [p + 1 for p in batch], pdf_name)
                            for page_num in batch:
                                extracted_text[page_num] = self._page_block(
                                    page_num, "[Could not extract page content]")
        except Exception as e:
            logger.error("Error extracting text from PDF %s: %s", pdf_path, e)
            return f"Error processing PDF: {str(e)}"

        logger.info("Processing complete. Processed %d pages, including %d scanned pages.",
                    total_pages, len(scanned_pages))
        # Keys are page indices, so sorting restores document order.
        return "".join(extracted_text[page_num] for page_num in sorted(extracted_text))

    @staticmethod
    def _format_dataframe(df: pd.DataFrame, separator: str) -> str:
        """
        Render a DataFrame as left-aligned, separator-delimited text lines.

        Each column is padded to its longest value, header included, capped at 50
        characters (longer values are printed in full, unpadded). Works for
        header-only frames (no rows).
        """
        cells = df.fillna('').astype(str)
        # str() first: formatting e.g. a datetime header with a width spec would
        # go through strftime and produce garbage.
        headers = [str(column) for column in cells.columns]
        # Positional access stays correct even with duplicate column labels.
        widths = [
            min(50, max([len(header)] + [len(value) for value in cells.iloc[:, position]]))
            for position, header in enumerate(headers)
        ]
        lines = [separator.join(f"{header:<{width}}" for header, width in zip(headers, widths))]
        for row in cells.itertuples(index=False, name=None):
            lines.append(separator.join(f"{cell:<{width}}" for cell, width in zip(row, widths)))
        return '\n'.join(lines)

    def extract_text_from_excel(self, excel_filepath: str, separator: str = "  ||  ") -> str:
        """
        Convert every sheet of an Excel file to aligned text, maintaining table structure.

        Args:
            excel_filepath (str): Path to the Excel file.
            separator (str, optional): Separator between columns. Defaults to "  ||  ".

        Returns:
            str: One "--- Sheet: <name> ---" section per sheet, or
                "Error processing Excel: <reason>" on failure
        """
        try:
            all_sheets = pd.read_excel(excel_filepath, sheet_name=None, header=0)
            parts = []
            for sheet_name, df in all_sheets.items():
                parts.append(f"\n--- Sheet: {sheet_name} ---\n")
                parts.append(self._format_dataframe(df, separator) + "\n\n")
            return "".join(parts)
        except Exception as e:
            logger.error("Error extracting from Excel %s: %s", excel_filepath, e)
            return f"Error processing Excel: {str(e)}"

    def extract_text_from_csv(self, csv_filepath: str) -> str:
        """
        Extract text from a CSV file.

        Tries UTF-8 first and falls back to latin-1, mirroring extract_text_from_txt.

        Args:
            csv_filepath (str): Path to the CSV file

        Returns:
            str: pandas table rendering of the CSV, or "Error processing CSV: <reason>"
        """
        try:
            try:
                df = pd.read_csv(csv_filepath)
            except UnicodeDecodeError:
                df = pd.read_csv(csv_filepath, encoding='latin-1')
            return df.fillna('').astype(str).to_string(index=False)
        except Exception as e:
            logger.error("Error extracting from CSV %s: %s", csv_filepath, e)
            return f"Error processing CSV: {str(e)}"

    def extract_text_from_txt(self, txt_filepath: str) -> str:
        """
        Extract text from a text file, trying UTF-8 and then latin-1.

        Args:
            txt_filepath (str): Path to the text file

        Returns:
            str: Text content of the file, or "Error processing text file: <reason>"
        """
        try:
            with open(txt_filepath, 'r', encoding='utf-8') as file:
                return file.read()
        except UnicodeDecodeError:
            # latin-1 maps every byte to a character, so this retry cannot fail on decoding.
            try:
                with open(txt_filepath, 'r', encoding='latin-1') as file:
                    return file.read()
            except Exception as e:
                logger.error("Error reading text file %s: %s", txt_filepath, e)
                return f"Error processing text file: {str(e)}"
        except Exception as e:
            logger.error("Error extracting from text file %s: %s", txt_filepath, e)
            return f"Error processing text file: {str(e)}"

    @classmethod
    def _load_image_for_claude(cls, image_path: str) -> Tuple[bytes, str]:
        """
        Return (image bytes, media type) in a format the Claude API accepts.

        JPEG/PNG/GIF/WebP files are sent unchanged; other formats (e.g. TIFF, BMP)
        are re-encoded as PNG with Pillow (first frame only).
        """
        extension = os.path.splitext(image_path)[1].lower()
        media_type = cls._CLAUDE_IMAGE_MEDIA_TYPES.get(extension)
        if media_type is not None:
            with open(image_path, 'rb') as image_file:
                return image_file.read(), media_type

        with Image.open(image_path) as img:
            if img.mode not in ('1', 'L', 'LA', 'P', 'RGB', 'RGBA'):
                img = img.convert('RGB')  # e.g. CMYK TIFFs cannot be written as PNG
            buffer = io.BytesIO()
            img.save(buffer, format='PNG')
        return buffer.getvalue(), 'image/png'

    def extract_text_from_image(self, image_path: str) -> str:
        """
        Extract text from an image file using Claude.

        Args:
            image_path (str): Path to the image file

        Returns:
            str: Extracted text, "Error: No Claude API key provided for image
                processing" without a client, or "Error processing image: <reason>"
        """
        try:
            if not self.claude_client:
                return "Error: No Claude API key provided for image processing"

            image_bytes, media_type = self._load_image_for_claude(image_path)
            image_data = base64.b64encode(image_bytes).decode('utf-8')

            response = self.claude_client.messages.create(
                model=self.CLAUDE_MODEL,
                max_tokens=4000,
                temperature=0.2,
                system=self._IMAGE_SYSTEM,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self._IMAGE_PROMPT},
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": image_data,
                                },
                            },
                        ],
                    }
                ],
            )
            return self._response_text(response)

        except Exception as e:
            logger.error("Error extracting text from image %s: %s", image_path, e)
            return f"Error processing image: {str(e)}"

    def extract_text_from_word(self, word_path: str, temp_dir: str) -> str:
        """
        Extract text from a Word document by converting it to PDF first.

        Args:
            word_path (str): Path to the Word document
            temp_dir (str): Directory under which a private scratch directory for
                the intermediate PDF is created (current directory if empty)

        Returns:
            str: Extracted text content, or "Error processing Word document: <reason>"
        """
        try:
            base_name = os.path.splitext(os.path.basename(word_path))[0]
            # Convert inside a fresh scratch directory. Writing "<name>.pdf" straight
            # into temp_dir (the document's own folder when called from
            # extract_text_from_file) would overwrite - and then delete - a real PDF
            # of the same name sitting next to the Word file.
            scratch_dir = tempfile.mkdtemp(prefix="docx2pdf_", dir=temp_dir or os.curdir)
            try:
                pdf_filename = os.path.join(scratch_dir, f"{base_name}.pdf")
                convert(word_path, pdf_filename)
                return self.extract_text_from_pdf(pdf_filename)
            finally:
                # Remove the intermediate PDF even when conversion/extraction fails.
                shutil.rmtree(scratch_dir, ignore_errors=True)
        except Exception as e:
            logger.error("Error extracting from Word %s: %s", word_path, e)
            return f"Error processing Word document: {str(e)}"

    def get_file_type(self, file_path: str) -> str:
        """
        Determine the file category from the (case-insensitive) extension.

        Args:
            file_path (str): Path to the file

        Returns:
            str: Category key of self.supported_extensions, or "unsupported"
        """
        ext = os.path.splitext(file_path)[1].lower()
        for file_type, extensions in self.supported_extensions.items():
            if ext in extensions:
                return file_type
        return "unsupported"

    def extract_text_from_file(self, file_path: str, temp_dir: Optional[str] = None) -> str:
        """
        Extract text from a single file based on its type.

        Args:
            file_path (str): Path to the file
            temp_dir (str, optional): Parent directory for intermediate files
                (defaults to the file's own directory)

        Returns:
            str: Extracted text content, or an error/unsupported message string
        """
        file_type = self.get_file_type(file_path)

        if temp_dir is None:
            temp_dir = os.path.dirname(file_path)

        try:
            if file_type == "pdf":
                return self.extract_text_from_pdf(file_path)
            elif file_type == "excel":
                return self.extract_text_from_excel(file_path)
            elif file_type == "word":
                return self.extract_text_from_word(file_path, temp_dir)
            elif file_type == "csv":
                return self.extract_text_from_csv(file_path)
            elif file_type == "text":
                return self.extract_text_from_txt(file_path)
            elif file_type == "image":
                return self.extract_text_from_image(file_path)
            else:
                # Report the offending extension (file_type is always "unsupported" here).
                extension = os.path.splitext(file_path)[1] or os.path.basename(file_path)
                return f"Unsupported file type: {extension}"
        except Exception as e:
            logger.error("Error extracting text from %s: %s", file_path, e)
            return f"Error processing file: {str(e)}"

    @classmethod
    def _is_error_result(cls, text: str) -> bool:
        """
        Return True if text is one of this class's failure strings.

        Matches specific prefixes rather than just "Error", so a real document
        whose content happens to begin with "Error" is not discarded.
        """
        return text.startswith(cls._ERROR_PREFIXES)

    def find_documents_recursively(self, root_dir: str) -> List[str]:
        """
        Recursively find all supported document files in a directory tree.

        Args:
            root_dir (str): Root directory to search

        Returns:
            List[str]: File paths in deterministic (sorted, top-down) order
        """
        found_files: List[str] = []
        for root, dirs, files in os.walk(root_dir):
            dirs.sort()  # in-place, so os.walk descends into sub-directories in sorted order
            for file_name in sorted(files):
                file_path = os.path.join(root, file_name)
                if self.get_file_type(file_path) != "unsupported":
                    found_files.append(file_path)
        return found_files

    def extract_documents_text_recursive(self, root_dir: str, cache_key: Optional[str] = None) -> Tuple[Dict[str, str], bool, str]:
        """
        Recursively extract text from all supported documents in a directory tree.

        Args:
            root_dir (str): Root directory containing documents to process
            cache_key (str, optional): Key for caching results; a cached result is
                returned as-is (the same dict object) on later calls

        Returns:
            Tuple[Dict[str, str], bool, str]:
                - Dictionary mapping file paths to extracted text
                - Success status (True if at least one file yielded text)
                - Status message
        """
        if cache_key and cache_key in self._extraction_cache:
            logger.info("Using cached extraction results for %s", cache_key)
            return self._extraction_cache[cache_key]

        if not os.path.exists(root_dir):
            return {}, False, f"Directory {root_dir} does not exist"
        if not os.path.isdir(root_dir):
            return {}, False, f"{root_dir} is not a directory"

        all_files = self.find_documents_recursively(root_dir)
        if not all_files:
            return {}, False, f"No supported document files found in {root_dir}"

        logger.info("Found %d supported files to process", len(all_files))

        documents_text: Dict[str, str] = {}
        successful_files: List[str] = []
        failed_files: List[str] = []
        start_time = datetime.now()

        for file_path in all_files:
            file_name = os.path.basename(file_path)
            try:
                logger.info("Processing %s...", file_name)
                text = self.extract_text_from_file(file_path, temp_dir=os.path.dirname(file_path))
            except Exception:
                logger.exception("Error processing %s", file_path)
                failed_files.append(file_path)
                continue

            if text and not self._is_error_result(text):
                documents_text[file_path] = text
                successful_files.append(file_path)
                logger.info("Successfully extracted text from %s", file_name)
            else:
                failed_files.append(file_path)
                logger.warning("Failed to extract text from %s", file_name)

        processing_time = datetime.now() - start_time
        if failed_files:
            logger.warning("Text extraction failed for %d of %d files: %s",
                           len(failed_files), len(all_files), failed_files)

        if not documents_text:
            result = {}, False, "No text was successfully extracted from any documents"
        else:
            total_chars = sum(len(text) for text in documents_text.values())
            result = documents_text, True, f"Successfully extracted text from {len(successful_files)} documents ({total_chars} total characters) in {processing_time}"

        if cache_key:
            self._extraction_cache[cache_key] = result

        return result


# Factory function for easy initialization
def create_enhanced_document_extractor(llm_api_key: str) -> EnhancedDocumentExtractor:
    """
    Build a ready-to-use EnhancedDocumentExtractor.

    Args:
        llm_api_key (str): Claude/Gemini API key, forwarded unchanged to the
            extractor's constructor

    Returns:
        EnhancedDocumentExtractor: The newly constructed extractor
    """
    extractor = EnhancedDocumentExtractor(llm_api_key=llm_api_key)
    return extractor