#!/usr/bin/env python3
import os
import sys
import argparse
from pathlib import Path
import re
import databricks.sql as sql
from databricks.sdk import WorkspaceClient


# Static settings for the DEV -> PROD view export.
CONFIG = {
    'dev': {
        # Profile name handed to WorkspaceClient(profile=...).
        'profile': 'documentation',
        # Catalog the views are read from.
        'catalog': 'datalakehouse_liquibase_dev',
        # Catalog name written into the generated DDL.
        'prod_catalog': 'datalakehouse_liquibase_prod',
        # Paired by position: dev_schemas[i] is mapped to prod_schemas[i],
        # so both lists must have the same length and order.
        'dev_schemas': ['c3dbdv_stage_dev', 'c3dbdv_rawvault_dev', 'c3dbdv_businessvault_dev'],
        'prod_schemas': ['c3dbdv_stage_prod', 'c3dbdv_rawvault_prod', 'c3dbdv_businessvault_prod'],
        # Connection details passed to databricks.sql.connect().
        'sql_server_hostname': 'dbc-84749d04-bda5.cloud.databricks.com',
        'http_path': '/sql/protocolv1/o/131701153478878/0218-135919-gia7cjd9',
    },
    # Directory the generated notebook is written to. This is a raw string,
    # so each path separator is written as a single backslash; doubling them
    # would put literal double backslashes into the path.
    'output_dir': r'C:\PROJECTS\liquibase_demo'
}


def get_prod_schema_mapping():
    """Return a dict pairing each configured dev schema with the prod schema at the same list position.

    Raises:
        ValueError: if the two configured schema lists differ in length.
    """
    dev_cfg = CONFIG['dev']
    sources, targets = dev_cfg['dev_schemas'], dev_cfg['prod_schemas']

    if len(sources) == len(targets):
        return {src: tgt for src, tgt in zip(sources, targets)}

    raise ValueError("dev_schemas and prod_schemas must have same length")


def connect_databricks_sdk():
    """Create a Databricks SDK WorkspaceClient for the profile named in CONFIG."""
    profile_name = CONFIG['dev']['profile']
    client = WorkspaceClient(profile=profile_name)
    return client


def connect_databricks_sql():
    """Open a Databricks SQL connection against the configured DEV catalog.

    The access token is read from the DATABRICKS_TOKEN environment variable
    (None is passed through when the variable is unset).
    """
    dev_cfg = CONFIG['dev']
    connection_kwargs = {
        'server_hostname': dev_cfg['sql_server_hostname'],
        'http_path': dev_cfg['http_path'],
        'catalog': dev_cfg['catalog'],
        'schema': "default",
        'access_token': os.getenv('DATABRICKS_TOKEN'),
    }
    return sql.connect(**connection_kwargs)


# One identifier part: either a backtick-delimited name (a doubled backtick is
# an escaped backtick) or a bare name that stops at whitespace, '.', '(' or '`'.
_IDENT_PART = r'(?:`(?:[^`]|``)*`|[^\s`.(]+)'

# The `CREATE [OR REPLACE] VIEW <possibly.qualified.name>` header of a view DDL.
_VIEW_HEADER_RE = re.compile(
    r'CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+' + _IDENT_PART + r'(?:\.' + _IDENT_PART + r')*',
    flags=re.IGNORECASE,
)


def _quote_ident(name):
    """Wrap *name* in backticks, doubling any backtick it contains."""
    return "`" + name.replace("`", "``") + "`"


def _rewrite_ddl_for_prod(ddl, view_name, dev_catalog, prod_catalog, prod_schema, schema_mapping):
    """Return *ddl* with DEV catalog/schema names swapped for PROD and a PROD `CREATE OR REPLACE VIEW` header."""
    # Body: plain substring replacement of every DEV catalog / schema name.
    ddl = ddl.replace(dev_catalog, prod_catalog)
    for dev_schema, mapped_schema in schema_mapping.items():
        ddl = ddl.replace(dev_schema, mapped_schema)

    # Header last, so the substitutions above can never alter the target name
    # and the created view always matches the name returned to the caller.
    header = "CREATE OR REPLACE VIEW " + ".".join(
        _quote_ident(part) for part in (prod_catalog, prod_schema, view_name)
    )
    # count=1: only the leading header is rewritten, not look-alike text in the
    # body. A callable replacement keeps backslashes in names literal.
    return _VIEW_HEADER_RE.sub(lambda _match: header, ddl, count=1)


def list_views(w_sdk, sql_conn, catalog, schema):
    """
    Fetch the DDL of every view in ``catalog.schema`` and retarget it to PROD.

    Views are discovered through the SDK table listing; an entry counts as a
    view when the string form of its ``table_type`` contains "view"
    (case-insensitive). The DDL itself comes from ``SHOW CREATE TABLE``.

    Args:
        w_sdk: client exposing ``tables.list(catalog_name=..., schema_name=...)``.
        sql_conn: connection exposing ``cursor()`` usable as a context manager.
        catalog: catalog to scan.
        schema: schema to scan; mapped to its PROD counterpart through
            get_prod_schema_mapping(), or kept as-is when it has no mapping.

    Returns:
        A list of ``(prod_schema, view_name, ddl)`` tuples, one per view whose
        DDL could be fetched. Views that fail or return empty DDL are reported
        on stdout and skipped.
    """
    view_names = [
        t.name
        for t in w_sdk.tables.list(catalog_name=catalog, schema_name=schema)
        if "view" in str(t.table_type).lower()
    ]

    print(f"      🛠️  Found {len(view_names)} views, fetching REAL DDL...")

    # Loop-invariant lookups, resolved once.
    schema_mapping = get_prod_schema_mapping()
    dev_catalog = CONFIG['dev']['catalog']
    prod_catalog = CONFIG['dev']['prod_catalog']
    prod_schema = schema_mapping.get(schema, schema)

    views = []
    with sql_conn.cursor() as cursor:
        for view_name in view_names:
            # Deliberately best-effort: one broken view must not abort the scan.
            try:
                full_name = ".".join(
                    _quote_ident(part) for part in (catalog, schema, view_name)
                )
                cursor.execute(f"SHOW CREATE TABLE {full_name}")
                row = cursor.fetchone()

                if not (row and row[0]):
                    print(f"      ⚠️  Empty DDL: {view_name}")
                    continue

                real_ddl = _rewrite_ddl_for_prod(
                    row[0], view_name, dev_catalog, prod_catalog, prod_schema, schema_mapping
                )
                views.append((prod_schema, view_name, real_ddl))
                print(f"      ✅ EXTRACTED: {view_name} → {prod_catalog}.{prod_schema}.{view_name}")

            except Exception as e:
                print(f"      ⚠️  DDL failed for {view_name}: {str(e)[:100]}")

    return views


def _escape_for_triple_quoted(text):
    """Escape *text* so a non-raw triple-double-quoted Python literal evaluates back to it unchanged."""
    # Backslashes first, otherwise the backslashes added for the quotes would be doubled.
    return text.replace('\\', '\\\\').replace('"""', '\\"\\"\\"')


def generate_notebook_files(objects, catalog):
    """Write a Databricks notebook source file that creates every view in *objects*.

    Each view gets its own Python cell that runs the DDL through ``spark.sql``
    inside try/except, so a failing view is recorded and the run continues.
    A final cell prints the success/failure counts and the list of failures.

    Args:
        objects: iterable-with-len of ``(schema, name, ddl)`` tuples. When it
            is empty, nothing is written.
        catalog: unused; kept so existing callers keep working.

    Returns:
        None. The notebook is written to ``views_to_deploy_prod_notebook.py``
        inside ``CONFIG['output_dir']``.
    """
    if not objects:
        print("   ⚠️ No views found - skipping file generation")
        return

    prod_catalog = CONFIG['dev']['prod_catalog']
    # parents=True: a missing intermediate directory should not abort the export.
    Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)

    notebook_path = os.path.join(CONFIG['output_dir'], "views_to_deploy_prod_notebook.py")

    with open(notebook_path, "w", encoding="utf-8") as f:
        # Header
        f.write("# Databricks notebook source\n")

        # Title cell
        f.write("# COMMAND ----------\n")
        f.write("%md\n")
        f.write("# 🚀 Deploy PROD Views from DEV (Error-Resilient)\n\n")
        f.write(f"**{len(objects)} views** | **Continues on ALL errors** ✅\n\n")
        f.write(f"**Target**: `{prod_catalog}` catalog\n")
        f.write("**Each cell**: TRY/CATCH → **never stops** on failures\n\n")
        f.write("---\n")

        # USE CATALOG cell
        f.write("# COMMAND ----------\n")
        f.write("%sql\n")
        f.write(f"USE CATALOG {prod_catalog};\n\n")

        # Exception tracking state shared by all following cells
        f.write("# COMMAND ----------\n")
        f.write("%python\n")
        f.write("# 📋 Initialize exception tracker\n")
        f.write("failed_views = []\n")
        f.write("total_views = 0\n\n")

        # One try/except Python cell per CREATE VIEW
        for i, (schema, name, ddl) in enumerate(objects, 1):
            view_ref = f'{prod_catalog}.{schema}.{name}'
            # repr() yields valid Python literals even if a name contains
            # quotes, braces or backslashes.
            success_literal = repr(f"✅ SUCCESS: {view_ref}")
            failed_literal = repr(f"⚠️ FAILED: {view_ref} | ")
            record_literal = repr(f"{view_ref}: ")

            f.write("# COMMAND ----------\n")
            f.write("%python\n")
            f.write(f"# 📊 VIEW {i}/{len(objects)}: {prod_catalog}.{schema}.{name}\n")
            f.write('total_views += 1\n')
            f.write('try:\n')
            f.write('    spark.sql("""\n')
            # The DDL sits inside a non-raw triple-quoted literal, so its
            # backslashes and triple quotes must be escaped to reach Spark intact.
            f.write(f'{_escape_for_triple_quoted(ddl.strip())}\n')
            f.write('    """)\n')
            f.write(f'    print({success_literal})\n')
            f.write('except Exception as e:\n')
            f.write(f'    error_msg = {failed_literal} + str(e)[:200]\n')
            f.write('    print(error_msg)\n')
            f.write(f'    failed_views.append({record_literal} + str(e)[:200])\n')
            f.write('\n')

        # Summary heading. Markdown cells do not evaluate Python expressions,
        # so the live counts are left to the Python cell that follows.
        f.write("# COMMAND ----------\n")
        f.write("%md\n")
        f.write("## 📊 **DEPLOYMENT SUMMARY**\n\n")
        f.write(f"- **Total views**: {len(objects)}\n")
        f.write("- **Success / Failures**: printed by the next cell\n\n")

        # Final report cell
        f.write("# COMMAND ----------\n")
        f.write("%python\n")
        f.write("# 🎯 FINAL EXCEPTION REPORT\n")
        f.write('print("\\n" + "="*80)\n')
        f.write('print("📋 FINAL SUMMARY")\n')
        f.write('print("="*80)\n')
        f.write('print(f"✅ SUCCESS: {total_views - len(failed_views)} / {total_views}")\n')
        f.write('print(f"❌ FAILED: {len(failed_views)} / {total_views}")\n')
        f.write('print("="*80)\n')
        f.write('\n')
        f.write('if failed_views:\n')
        f.write('    print("\\n📋 FAILED VIEWS LIST:")\n')
        f.write('    print("-" * 60)\n')
        f.write('    for i, failure in enumerate(failed_views, 1):\n')
        f.write('        print(f"{i:2d}. {failure}")\n')
        f.write('    print("\\n" + "-" * 60)\n')
        f.write('    print("🔧 ACTION REQUIRED:Re-run")\n')
        f.write('else:\n')
        f.write('    print("🎉 ALL VIEWS DEPLOYED SUCCESSFULLY! ✅")\n')
        f.write('\n')
        f.write('print("="*80)\n')
        f.write('# Next: Run Liquibase deployment\n')

    print(f"✅ Notebook created: {notebook_path}")
    print(f"📊 {len(objects)} TRY/CATCH cells + 1 SUMMARY cell generated")
    print("🚀 Import → Run All → **See ALL exceptions at the END!**")

def main():
    """Extract every view from the configured DEV schemas and write the PROD deployment notebook.

    Returns:
        0 when the run completed, 1 when an exception was caught. The
        exception is printed with its traceback rather than re-raised, and
        the SQL connection is closed either way.
    """
    w_sdk = sql_conn = None
    exit_code = 0
    try:
        print("🔌 Connecting to Databricks DEV...")
        w_sdk = connect_databricks_sdk()
        sql_conn = connect_databricks_sql()

        catalog = CONFIG['dev']['catalog']
        dev_schemas = CONFIG['dev']['dev_schemas']
        all_views = []

        print("🔍 Extracting views definitions...")
        for dev_schema in dev_schemas:
            print(f"   📁 Scanning DEV schema: {dev_schema}")
            views = list_views(w_sdk, sql_conn, catalog, dev_schema)
            all_views.extend(views)
            prod_schema = get_prod_schema_mapping().get(dev_schema, dev_schema)
            print(f"     → Will generate in PROD schema: {prod_schema}")
            print(f"     Found {len(views)} views")

        print(f"📊 Total: {len(all_views)} views")
        generate_notebook_files(all_views, catalog)

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        # Remember the failure so the process does not report success.
        exit_code = 1
    finally:
        if sql_conn is not None:
            sql_conn.close()
        print("✅ Done.")

    return exit_code


if __name__ == "__main__":
    # Propagate main()'s status so shells and schedulers can detect a failed run.
    sys.exit(main())
