import os
import sys
import time
import threading
import json
import sqlite3
import re
import webbrowser
import io
import subprocess
import shutil
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, render_template, jsonify, send_from_directory, request, send_file
import fitz  # PyMuPDF

# ---------------- Configuration & Setup ----------------
# Application data lives in an "AdvancedPDF_Data" folder under %APPDATA% when
# that environment variable is set, otherwise under the current working
# directory. The folder is created as a side effect of importing this module.
DATA_DIR = os.path.join(os.getenv('APPDATA', os.getcwd()), "AdvancedPDF_Data")
os.makedirs(DATA_DIR, exist_ok=True)
# SQLite database, JSON settings file and cached scan result (tree + flat list).
DB_PATH = os.path.join(DATA_DIR, "library.db")
CONFIG_FILE = os.path.join(DATA_DIR, "config.json")
TREE_FILE = os.path.join(DATA_DIR, "tree.json")

# Flask Setup
# When running as a frozen executable (sys.frozen is set), templates and static
# files are looked up under sys._MEIPASS; otherwise Flask's default template
# folder and the local 'static' folder are used.
if getattr(sys, 'frozen', False):
    template_folder = os.path.join(sys._MEIPASS, 'templates')
    static_folder = os.path.join(sys._MEIPASS, 'static')
    app = Flask(__name__, template_folder=template_folder, static_folder=static_folder)
else:
    app = Flask(__name__, static_folder='static')

# Global State
# These dicts are mutated in place by the background workers and read by the
# status routes, which return them as JSON.
# current_config: persisted settings (scan roots and UI theme).
current_config = {'root_paths': [], 'theme_mode': 0}
# scan_status: progress of the file scan worker.
scan_status = {'is_scanning': False, 'total_files': 0, 'done': False}
# indexer_status: progress of the text indexer; 'stop_flag' requests a stop.
indexer_status = {'is_indexing': False, 'total_files': 0, 'processed': 0, 'current': '', 'done': False, 'stop_flag': False}
# server_status: time of the last /heartbeat request, read by monitor_shutdown.
server_status = {'last_heartbeat': time.time()}

# ---------------- Database Helpers ----------------
def get_db_connection():
    """Open a new connection to the library database at DB_PATH.

    The connection waits up to 30 seconds for a lock held by another
    connection, and returns rows as ``sqlite3.Row`` so columns can be read
    by name.

    Returns:
        A new ``sqlite3.Connection``; the caller owns it.
    """
    connection = sqlite3.connect(DB_PATH, timeout=30.0)
    connection.row_factory = sqlite3.Row
    return connection

def init_db():
    """Create the database schema if it does not exist yet.

    Tables:
        files     -- one row per PDF (path, filename, mod_time, scanned flag).
        pages     -- FTS5 virtual table holding the text of each indexed page.
        favorites -- paths marked as favourites.
        history   -- opened paths with their last access time.

    Errors are printed and never raised. A failure to create the FTS5 table
    is reported separately and does not prevent the other tables from being
    created.
    """
    try:
        conn = get_db_connection()
        try:
            # `with conn` commits or rolls back; it does not close the connection.
            with conn:
                c = conn.cursor()
                c.execute('''CREATE TABLE IF NOT EXISTS files (
                            id INTEGER PRIMARY KEY AUTOINCREMENT,
                            path TEXT UNIQUE, filename TEXT, mod_time REAL, scanned INTEGER DEFAULT 0)''')
                try:
                    c.execute("CREATE VIRTUAL TABLE IF NOT EXISTS pages USING fts5(file_id UNINDEXED, page_num UNINDEXED, content)")
                except sqlite3.Error as fts_err:
                    # Best effort: keep going so the rest of the schema exists,
                    # but say why full-text search will not work.
                    print(f"DB Init Warning: could not create FTS5 table 'pages': {fts_err}")
                c.execute('''CREATE TABLE IF NOT EXISTS favorites (path TEXT PRIMARY KEY)''')
                c.execute('''CREATE TABLE IF NOT EXISTS history (path TEXT PRIMARY KEY, last_access REAL)''')
                conn.commit()
        finally:
            conn.close()
    except Exception as e:
        print(f"DB Init Error: {e}")

def load_config():
    """Merge the settings stored in CONFIG_FILE into ``current_config``.

    Does nothing when the file does not exist. Any error while reading or
    applying the file is printed and not raised.
    """
    if not os.path.exists(CONFIG_FILE):
        return

    try:
        with open(CONFIG_FILE, 'r', encoding='utf-8') as config_fh:
            stored_settings = json.load(config_fh)
        current_config.update(stored_settings)
    except Exception as err:
        print(f"Load Config Error: {err}")

def save_config():
    """Write ``current_config`` to CONFIG_FILE as JSON.

    Any error while writing is printed and not raised.
    """
    try:
        with open(CONFIG_FILE, mode='w', encoding='utf-8') as config_fh:
            json.dump(current_config, config_fh)
    except Exception as err:
        print(f"Save Config Error: {err}")

# ---------------- File Scanning Logic ----------------
def scan_files_on_disk(root_paths):
    """Walk the given folders and synchronise the ``files`` table with them.

    New PDFs are inserted, PDFs whose modification time changed are marked
    for re-indexing (``scanned=0``), and rows whose file was not found in any
    root are deleted. Only the ``files`` table is touched here; page text
    already stored in ``pages`` is left in place.

    Args:
        root_paths: Iterable of directory paths. Roots that do not exist are
            skipped. Roots may overlap or repeat; each file is recorded once.

    Returns:
        A list of ``{'n': filename, 'p': full path, 'd': parent folder name}``
        dicts, one per PDF found, with forward slashes in every path. If an
        error occurs it is printed and whatever was collected so far is
        returned.
    """
    flat_list = []
    found_paths = set()
    to_insert = []
    to_rescan = []

    try:
        conn = get_db_connection()
        try:
            # `with conn` commits on success and rolls back on error.
            with conn:
                c = conn.cursor()
                c.execute("SELECT path, mod_time FROM files")
                db_files = {row['path']: row['mod_time'] for row in c.fetchall()}

                for root_p in root_paths:
                    if not os.path.exists(root_p):
                        continue
                    root_p = os.path.normpath(root_p).replace('\\', '/')

                    for root, _, files in os.walk(root_p):
                        for f in files:
                            if not f.lower().endswith('.pdf'):
                                continue

                            full_p = os.path.join(root, f).replace('\\', '/')
                            # Nested or repeated roots reach the same file more
                            # than once; a second INSERT of the same path would
                            # violate the UNIQUE constraint and undo the scan.
                            if full_p in found_paths:
                                continue
                            found_paths.add(full_p)

                            try:
                                mtime = os.path.getmtime(full_p)
                            except OSError:
                                mtime = 0

                            # Queue the change; writes are applied after the
                            # walk so the write transaction stays short.
                            if full_p not in db_files:
                                to_insert.append((full_p, f, mtime))
                            elif db_files[full_p] != mtime:
                                to_rescan.append((mtime, full_p))

                            parent_dir = os.path.dirname(full_p)
                            folder_name = os.path.basename(parent_dir) or parent_dir
                            flat_list.append({'n': f, 'p': full_p, 'd': folder_name})

                # Apply all database changes in one batch.
                if to_rescan:
                    c.executemany("UPDATE files SET mod_time=?, scanned=0 WHERE path=?", to_rescan)
                if to_insert:
                    c.executemany("INSERT INTO files (path, filename, mod_time) VALUES (?,?,?)", to_insert)

                # Remove rows for files that are no longer on disk.
                stale = [(p,) for p in db_files if p not in found_paths]
                if stale:
                    c.executemany("DELETE FROM files WHERE path=?", stale)

                conn.commit()
        finally:
            conn.close()

    except Exception as e:
        print(f"Disk Scan Error: {e}")

    return flat_list

def build_directory_tree(flat_list, root_paths):
    """Build the nested folder structure sent to the frontend.

    The contents of every root are merged directly under a single top-level
    folder node; the root directories themselves do not appear as nodes.

    Args:
        flat_list: Dicts with at least ``'n'`` (filename) and ``'p'`` (full
            path using forward slashes), as produced by the disk scan.
        root_paths: Root directories. A file belongs to a root only when its
            path lies inside that directory, so a root ``/a/Books`` does not
            claim files from a sibling ``/a/Books2``. Roots that normalise to
            the same path are processed once.

    Returns:
        A dict ``{'name', 'type': 'folder', 'children': [...]}`` whose
        children are folder nodes of the same shape or file nodes
        ``{'name', 'type': 'file', 'path'}``.
    """
    tree = {'name': 'ספרים', 'type': 'folder', 'children': []}

    # Normalise the roots and drop exact repeats while keeping their order.
    normalized_roots = dict.fromkeys(
        os.path.normpath(root_p).replace('\\', '/') for root_p in root_paths
    )

    for root_p in normalized_roots:
        # Compare against "root/" rather than "root" so that only real
        # descendants match (a plain prefix test also matches "root2/...").
        prefix = root_p if root_p.endswith('/') else root_p + '/'

        for item in flat_list:
            if not item['p'].startswith(prefix):
                continue

            # Path of the file relative to this root, e.g. "sub/dir/file.pdf".
            parts = item['p'][len(prefix):].split('/')

            curr = tree
            # Walk down the folder chain, creating missing folders on the way.
            for part in parts[:-1]:
                found = next(
                    (c for c in curr['children']
                     if c.get('name') == part and c['type'] == 'folder'),
                    None,
                )
                if not found:
                    found = {'name': part, 'type': 'folder', 'children': []}
                    curr['children'].append(found)
                curr = found

            # Attach the file to the folder it lives in.
            curr['children'].append({'name': item['n'], 'type': 'file', 'path': item['p']})

    return tree

def file_scan_worker():
    """Run a full file scan and refresh the cached tree in TREE_FILE.

    Intended to run in a background thread. Progress is published through
    ``scan_status``: ``is_scanning`` is set while running, ``total_files`` is
    the number of PDFs found, and ``done`` is set when the worker finishes,
    whether it succeeded or not. Errors are printed and not raised.
    """
    scan_status.update({'is_scanning': True, 'done': False})

    try:
        flat_list = scan_files_on_disk(current_config['root_paths'])
        tree = build_directory_tree(flat_list, current_config['root_paths'])
        scan_status['total_files'] = len(flat_list)

        payload = {'tree': tree, 'flat': flat_list}

        # Only rewrite the cache file when its content actually changed.
        write_needed = True
        if os.path.exists(TREE_FILE):
            try:
                with open(TREE_FILE, 'r', encoding='utf-8') as f:
                    write_needed = json.load(f) != payload
            except (OSError, ValueError):
                # Unreadable or corrupt cache: overwrite it.
                write_needed = True

        if write_needed:
            with open(TREE_FILE, 'w', encoding='utf-8') as f:
                json.dump(payload, f)

    except Exception as e:
        print(f"Scan Worker Error: {e}")
    finally:
        scan_status.update({'done': True, 'is_scanning': False})

# ---------------- Indexing Logic (Multi-threaded) ----------------
def index_single_pdf(task):
    """Extract the text of one PDF and store it in the ``pages`` table.

    The text is extracted first and written afterwards in one short
    transaction, so the database write lock is not held while the PDF is
    being parsed. Existing ``pages`` rows for the file are replaced and the
    file is marked ``scanned=1``. If anything fails, the database is left
    unchanged for this file.

    Args:
        task: Tuple ``(file_id, path, filename)``.

    Returns:
        True on success, False on failure (the error is printed).
    """
    fid, path, fname = task

    try:
        # Suppress MuPDF errors; ignored when this call is not available.
        try:
            fitz.TOOLS.mupdf_display_errors(False)
        except Exception:
            pass

        # 1. Extract text with no database connection open.
        page_rows = []
        doc = fitz.open(path)
        try:
            for page_num, page in enumerate(doc, start=1):
                text = page.get_text()
                # Skip pages with (almost) no text.
                if text and len(text.strip()) > 3:
                    # Collapse every run of whitespace into a single space.
                    clean_text = re.sub(r'\s+', ' ', text).strip()
                    page_rows.append((fid, page_num, clean_text))
        finally:
            # Release the document even when extraction fails.
            doc.close()

        # 2. Replace the stored pages. Each call uses its own connection.
        conn = get_db_connection()
        try:
            with conn:
                c = conn.cursor()
                c.execute("DELETE FROM pages WHERE file_id=?", (fid,))
                c.executemany("INSERT INTO pages VALUES (?, ?, ?)", page_rows)
                c.execute("UPDATE files SET scanned=1 WHERE id=?", (fid,))
        finally:
            conn.close()
        return True  # Success
    except Exception as e:
        print(f"Indexing Error [{fname}]: {e}")
        return False  # Failure

def content_indexer_worker():
    """Index every file with ``scanned=0`` using a pool of four threads.

    Intended to run in a background thread. Progress is published through
    ``indexer_status`` (``total_files``, ``processed``, ``current``). Setting
    ``indexer_status['stop_flag']`` cancels all tasks that have not started
    yet; tasks that are already running finish normally. ``done`` is set and
    ``is_indexing`` cleared when the worker returns. Errors are printed and
    not raised.
    """
    indexer_status.update({'is_indexing': True, 'processed': 0, 'done': False, 'stop_flag': False})

    try:
        # 1. Fetch tasks
        conn = get_db_connection()
        try:
            rows = conn.execute("SELECT id, path, filename FROM files WHERE scanned=0").fetchall()
        finally:
            conn.close()
        # Plain tuples, detached from the closed connection.
        tasks = [(r['id'], r['path'], r['filename']) for r in rows]

        indexer_status['total_files'] = len(tasks)

        # 2. Process in parallel
        with ThreadPoolExecutor(max_workers=4) as executor:
            future_to_file = {executor.submit(index_single_pdf, task): task for task in tasks}

            for future, task_info in future_to_file.items():
                if indexer_status['stop_flag']:
                    # Every task was submitted up front and leaving the
                    # `with` block waits for all of them, so the queued ones
                    # must be cancelled for the stop to take effect.
                    # cancel() has no effect on running or finished tasks.
                    for pending in future_to_file:
                        pending.cancel()
                    break

                indexer_status['current'] = task_info[2]  # Filename shown in the UI

                try:
                    future.result()  # Wait for this task
                except Exception as e:
                    print(f"Indexing Task Error [{task_info[2]}]: {e}")

                indexer_status['processed'] += 1

    except Exception as e:
        print(f"Indexer Worker Error: {e}")
    finally:
        indexer_status.update({'done': True, 'is_indexing': False, 'current': ''})

# ---------------- App Lifecycle ----------------
def monitor_shutdown():
    """Exit the process once heartbeats stop arriving.

    Checks once per second. Nothing happens during the first 15 seconds after
    this function starts; after that the process is ended with
    ``os._exit(0)`` as soon as ``server_status['last_heartbeat']`` is more
    than 5 seconds old. This function never returns.
    """
    launched_at = time.time()
    while True:
        time.sleep(1)
        now = time.time()
        past_startup_grace = now - launched_at >= 15
        heartbeat_is_stale = now - server_status['last_heartbeat'] > 5
        if past_startup_grace and heartbeat_is_stale:
            os._exit(0)

def open_app_window(url):
    """Open ``url`` in a Chrome/Edge app window, or in the default browser.

    Looks for ``chrome``, ``google-chrome`` or ``msedge`` on PATH, then (on
    Windows) in a few common install locations. If a browser is found it is
    started with ``--app=<url>``. If none is found, or starting it fails,
    the URL is opened with ``webbrowser.open`` instead.

    Args:
        url: Address of the running server.
    """
    browser_path = None

    for candidate in ('chrome', 'google-chrome', 'msedge'):
        located = shutil.which(candidate)
        if located:
            browser_path = located
            break

    if not browser_path and os.name == 'nt':
        common_paths = [
            r"C:\Program Files\Google\Chrome\Application\chrome.exe",
            r"C:\Program Files (x86)\Google\Chrome\Application\chrome.exe",
            r"C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe",
        ]
        browser_path = next((p for p in common_paths if os.path.exists(p)), None)

    if browser_path:
        try:
            subprocess.Popen([
                browser_path,
                f'--app={url}',
                '--new-window',
                '--start-maximized',
                '--disable-infobars',
            ])
            return
        except Exception as e:
            # Still best effort, but report why the app window did not open
            # (and no longer swallow KeyboardInterrupt/SystemExit).
            print(f"App Window Error: {e}")

    webbrowser.open(url, new=1)

# ---------------- Flask Routes ----------------
@app.route('/')
def index():
    """Serve the main page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page

@app.route('/scan_start')
def scan_start():
    """Start a background file scan unless one is already running.

    The response is ``{'status': 'started'}`` in both cases.
    """
    already_scanning = scan_status['is_scanning']
    if not already_scanning:
        scanner = threading.Thread(target=file_scan_worker, daemon=True)
        scanner.start()
    return jsonify(status='started')

@app.route('/scan_status')
def get_scan_status():
    """Return the current ``scan_status`` dict as JSON."""
    response = jsonify(scan_status)
    return response

@app.route('/index_start')
def index_start():
    """Start the background content indexer unless it is already running.

    The response is ``{'status': 'started'}`` in both cases.
    """
    already_indexing = indexer_status['is_indexing']
    if not already_indexing:
        indexer = threading.Thread(target=content_indexer_worker, daemon=True)
        indexer.start()
    return jsonify(status='started')

@app.route('/index_stop')
def index_stop():
    """Request the indexer to stop by setting ``indexer_status['stop_flag']``."""
    indexer_status.update(stop_flag=True)
    return jsonify(status='stopping')

@app.route('/index_status')
def get_index_status():
    """Return the current ``indexer_status`` dict as JSON."""
    response = jsonify(indexer_status)
    return response

@app.route('/clear_index')
def clear_index():
    """Drop all indexed page text and mark every file for re-indexing.

    The ``pages`` FTS5 table is dropped and recreated, and ``scanned`` is
    reset to 0 on every row of ``files``. Once that is committed, a VACUUM is
    attempted to reclaim disk space; a VACUUM failure is printed but does
    not turn the response into an error.

    Returns:
        ``{'status': 'ok'}`` on success, otherwise
        ``{'status': 'error', 'message': ...}``.
    """
    try:
        conn = get_db_connection()
        try:
            # `with conn` commits on success and rolls back on error.
            with conn:
                c = conn.cursor()
                c.execute("DROP TABLE IF EXISTS pages")
                c.execute("CREATE VIRTUAL TABLE pages USING fts5(file_id UNINDEXED, page_num UNINDEXED, content)")
                c.execute("UPDATE files SET scanned=0")
            # VACUUM must run after the commit above: the UPDATE opens an
            # implicit transaction and SQLite refuses to VACUUM inside one,
            # which would otherwise roll back the scanned=0 reset.
            try:
                conn.execute("VACUUM")
            except sqlite3.Error as vacuum_err:
                print(f"Clear Index Warning: VACUUM failed: {vacuum_err}")
        finally:
            conn.close()
        return jsonify({'status': 'ok'})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})

@app.route('/search')
def search():
    """Full-text search over the indexed pages.

    Query parameter ``q`` is matched as a single phrase: it is wrapped in
    double quotes, with embedded quotes doubled, before being passed to FTS5
    as a bound parameter. An empty query returns an empty list without
    touching the database.

    Returns:
        JSON list (at most 500 entries) of
        ``{'path', 'name', 'folder', 'page', 'snippet'}``. On a database
        error the error is printed and an empty list is returned.
    """
    q = request.args.get('q', '').strip()
    if not q:
        # Nothing to search for; avoid sending an empty phrase to FTS5.
        return jsonify([])

    # snippet(): column 2 is `content`; no highlight markers, '...' as the
    # ellipsis text, token limit of 40.
    query = """SELECT f.path, f.filename, p.page_num, snippet(pages, 2, '', '', '...', 40) as snip
               FROM pages p JOIN files f ON p.file_id = f.id WHERE p.content MATCH ? LIMIT 500"""
    res = []
    try:
        # Quote the query so FTS5 treats it as a phrase rather than parsing
        # operators out of the user's text.
        q_safe = q.replace('"', '""')
        conn = get_db_connection()
        try:
            rows = conn.execute(query, (f'"{q_safe}"',)).fetchall()
        finally:
            conn.close()

        for r in rows:
            parent_dir = os.path.dirname(r['path'])
            folder_name = os.path.basename(parent_dir) or parent_dir
            res.append({
                'path': r['path'],
                'name': r['filename'],
                'folder': folder_name,
                'page': r['page_num'],
                'snippet': r['snip'],
            })
    except Exception as e:
        print(f"Search error: {e}")

    return jsonify(res)

@app.route('/fav_toggle', methods=['POST'])
def fav_toggle():
    """Add the posted path to the favourites, or remove it if already there.

    Expects a JSON body ``{'path': <non-empty string>}``.

    Returns:
        ``{'status': 'ok', 'added': bool}`` where ``added`` tells whether the
        path is a favourite after the call, or a 400 response with
        ``{'status': 'error', 'message': ...}`` when no usable path was sent.
    """
    data = request.get_json(silent=True)
    path = data.get('path') if isinstance(data, dict) else None
    # A missing path would be stored as NULL, which the primary key does not
    # reject, and fav_list would then fail on it.
    if not isinstance(path, str) or not path:
        return jsonify({'status': 'error', 'message': 'Missing path'}), 400

    added = False
    conn = get_db_connection()
    try:
        with conn:
            c = conn.cursor()
            try:
                c.execute("INSERT INTO favorites (path) VALUES (?)", (path,))
                added = True
            except sqlite3.IntegrityError:
                # Primary-key conflict: already a favourite, so remove it.
                c.execute("DELETE FROM favorites WHERE path=?", (path,))
                added = False
            conn.commit()
    finally:
        conn.close()
    return jsonify({'status': 'ok', 'added': added})

@app.route('/fav_list')
def fav_list():
    """Return all favourites as a JSON list of ``{'path', 'name'}`` dicts.

    ``name`` is the last component of the stored path.
    """
    favourites = []
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT path FROM favorites")
        for row in cursor.fetchall():
            stored_path = row['path']
            favourites.append({'path': stored_path, 'name': os.path.basename(stored_path)})
    return jsonify(favourites)

@app.route('/history_log', methods=['POST'])
def history_log():
    """Record that the posted path was opened just now.

    Expects a JSON body with a ``path`` key. An empty or missing path is
    ignored. The response is ``{'status': 'ok'}`` either way.
    """
    opened_path = request.json.get('path')
    if not opened_path:
        return jsonify({'status': 'ok'})

    with get_db_connection() as conn:
        cursor = conn.cursor()
        # INSERT OR REPLACE keeps one row per path and refreshes its time.
        cursor.execute(
            "INSERT OR REPLACE INTO history (path, last_access) VALUES (?, ?)",
            (opened_path, time.time()),
        )
        conn.commit()
    return jsonify({'status': 'ok'})

@app.route('/get_history')
def get_history():
    """Return the 50 most recently accessed paths, newest first.

    Each entry is ``{'path', 'name'}`` where ``name`` is the last component
    of the path.
    """
    recent = []
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT path FROM history ORDER BY last_access DESC LIMIT 50")
        for row in cursor.fetchall():
            stored_path = row['path']
            recent.append({'path': stored_path, 'name': os.path.basename(stored_path)})
    return jsonify(recent)

@app.route('/pdf')
def get_pdf():
    """Serve a PDF, optionally with every occurrence of ``q`` highlighted.

    Query parameters:
        path: Path of the file on disk.
        q: Optional text to highlight. When given, a highlighted copy is
            built in memory; if that fails, the error is printed and the
            original file is served.

    Returns:
        The PDF response, or 404 when the path is missing or not a file.
    """
    # NOTE(review): `path` comes straight from the query string, so this
    # route will serve any file the process can read. Confirm that is
    # acceptable for how the server is exposed.
    path = request.args.get('path')
    q = request.args.get('q', '')
    if not path or not os.path.isfile(path):
        return "File not found", 404

    if q:
        try:
            doc = fitz.open(path)
            try:
                for page in doc:
                    for inst in page.search_for(q):
                        page.add_highlight_annot(inst)
                pdf_bytes = doc.write()
            finally:
                # Release the document even when highlighting fails.
                doc.close()
            return send_file(io.BytesIO(pdf_bytes), mimetype='application/pdf')
        except Exception as e:
            # Fall back to the original file below.
            print(f"Highlight error: {e}")

    return send_from_directory(os.path.dirname(path), os.path.basename(path))

@app.route('/preview_page')
def preview_page():
    """Serve a single page of a PDF as its own one-page PDF.

    Query parameters:
        path: Path of the file on disk.
        page: 1-based page number (default 1).
        q: Optional text to highlight on that page.

    Returns:
        The one-page PDF; 400 when ``page`` is not a positive integer; 404
        when the file or the page does not exist; 500 when building the
        preview fails (the error is printed).
    """
    path = request.args.get('path')
    q = request.args.get('q', '')

    try:
        page_num = int(request.args.get('page', 1))
    except (TypeError, ValueError):
        return "Invalid page number", 400
    if page_num < 1:
        # Page numbers are 1-based; zero or negative values name no page.
        return "Invalid page number", 400

    if not path or not os.path.exists(path):
        return "File not found", 404

    try:
        doc = fitz.open(path)
        try:
            if page_num > len(doc):
                return "Page not found", 404

            new_doc = fitz.open()
            try:
                # insert_pdf takes 0-based page indexes.
                new_doc.insert_pdf(doc, from_page=page_num-1, to_page=page_num-1)

                if q:
                    for inst in new_doc[0].search_for(q):
                        new_doc[0].add_highlight_annot(inst)

                pdf_bytes = new_doc.write()
            finally:
                new_doc.close()
        finally:
            # Release the source document even when the preview fails.
            doc.close()
        return send_file(io.BytesIO(pdf_bytes), mimetype='application/pdf')
    except Exception as e:
        print(f"Preview error: {e}")
        return "Error creating preview", 500

@app.route('/browse')
def browse():
    """Let the user pick a folder through a native dialog.

    The dialog runs in a separate interpreter process (``sys.executable -c``)
    that shows a tkinter directory chooser and writes the chosen path to its
    stdout as UTF-8.

    Returns:
        ``{'status': 'ok', 'path': ...}`` with forward slashes when a folder
        was chosen, otherwise ``{'status': 'cancel'}``. Errors are printed
        and also reported as ``cancel``.
    """
    dialog_script = (
        "import tkinter as tk, sys; from tkinter import filedialog; "
        "root=tk.Tk(); root.withdraw(); root.attributes('-topmost', True); "
        "path=filedialog.askdirectory(); root.destroy(); "
        "sys.stdout.buffer.write(path.encode('utf-8'))"
    )
    # 0x08000000 is the value of subprocess.CREATE_NO_WINDOW on Windows.
    popen_options = {'creationflags': 0x08000000} if os.name == 'nt' else {}

    try:
        raw_output = subprocess.check_output(
            [sys.executable, '-c', dialog_script], **popen_options
        )
        chosen = raw_output.decode('utf-8').strip()
        if chosen:
            return jsonify({'status': 'ok', 'path': chosen.replace('\\', '/')})
    except Exception as e:
        print(f"Browse Error: {e}")

    return jsonify({'status': 'cancel'})

@app.route('/update_paths', methods=['POST'])
def update_paths():
    """Replace the configured scan roots and persist the configuration.

    Expects a JSON body with a ``paths`` key; a missing key means an empty
    list.
    """
    new_roots = request.json.get('paths', [])
    current_config.update(root_paths=new_roots)
    save_config()
    return jsonify(status='ok')

@app.route('/save_prefs', methods=['POST'])
def save_prefs():
    """Store the posted ``theme_mode`` (default 0) and persist the configuration."""
    prefs = request.json
    chosen_theme = prefs.get('theme_mode', 0)
    current_config.update(theme_mode=chosen_theme)
    save_config()
    return jsonify(status='ok')

@app.route('/get_tree')
def get_tree():
    """Return the cached scan result stored in TREE_FILE.

    Returns:
        The JSON content of TREE_FILE, or an empty tree
        ``{'tree': {'children': []}, 'flat': []}`` when the file is missing,
        unreadable or not valid JSON (the error is printed).
    """
    if os.path.exists(TREE_FILE):
        try:
            with open(TREE_FILE, 'r', encoding='utf-8') as f:
                cached = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON and text decoding errors.
            print(f"Get Tree Error: {e}")
        else:
            return jsonify(cached)
    return jsonify({'tree': {'children': []}, 'flat': []})

@app.route('/get_init')
def get_init():
    """Return the settings the frontend needs at start-up: scan roots and theme mode."""
    init_payload = {
        'paths': current_config['root_paths'],
        'theme_mode': current_config['theme_mode'],
    }
    return jsonify(init_payload)

@app.route('/heartbeat')
def heartbeat():
    """Record the time of this request as the latest heartbeat."""
    server_status.update(last_heartbeat=time.time())
    return jsonify(status='ok')

# ---------------- Entry Point ----------------
if __name__ == '__main__':
    # Schema and saved settings first, so they are ready before serving.
    init_db()
    load_config()

    # Watchdog thread that ends the process when heartbeats stop.
    watchdog = threading.Thread(target=monitor_shutdown, daemon=True)
    watchdog.start()

    # Open the UI only when WERKZEUG_RUN_MAIN is unset or empty.
    reloader_marker = os.environ.get("WERKZEUG_RUN_MAIN")
    if not reloader_marker:
        open_app_window("http://127.0.0.1:5000")

    app.run(port=5000, debug=False)