#!/usr/bin/env python3
import snowflake.connector
import re
import os
import sys
from pathlib import Path
import argparse

# Static settings for the extraction run.
#   'dev'        - keyword arguments passed verbatim to snowflake.connector.connect()
#   'schemas'    - schemas whose views/procedures are extracted
#   'output_dir' - directory that receives the generated .sql file
#
# SECURITY: the key passphrase should not live in source control. It is read
# from the SNOWFLAKE_PRIVATE_KEY_PASSPHRASE environment variable when set; the
# literal below is kept only as a fallback so existing setups keep working.
# NOTE(review): snowflake-connector-python documents `private_key_file_pwd` as
# the password parameter for `private_key_file` - confirm that the
# `private_key_passphrase` key is actually honoured by the installed version.
CONFIG = {
    'dev': {
        'account': 'ud93525.eu-west-1',
        'user': 'TRAINING_LIQUIBASE_USER_01',
        'private_key_file': r'C:\TEMP\Snowflake_private_key\snowflake_key.p8',
        'private_key_passphrase': os.environ.get(
            'SNOWFLAKE_PRIVATE_KEY_PASSPHRASE', 'bigenius'
        ),
        'database': 'SNOWFLAKE_LIQUIBASE_USER_01',
        'warehouse': 'TRAINING_WH_LIQUIBASE',
        'role': 'R_SNOWFLAKE_LIQUIBASE_USER_01'
    },
    'schemas': ['SA', 'RDV', 'BV'],
    'output_dir': r'C:\PROJECTS\liquibase_demo'
}

def connect_snowflake():
    """Open and return a Snowflake connection for the DEV environment.

    Every key/value pair of ``CONFIG['dev']`` is forwarded as a keyword
    argument to ``snowflake.connector.connect``.
    """
    dev_settings = CONFIG['dev']
    connection = snowflake.connector.connect(**dev_settings)
    return connection

def fix_view_ddl(ddl, schema):
    """Prefix the view name in a ``CREATE OR REPLACE VIEW`` header with *schema*.

    Args:
        ddl: DDL text of a single view.
        schema: Schema name to put in front of the view name.

    Returns:
        The DDL with leading/trailing whitespace stripped. When the header
        names an unqualified view it is rewritten to the normalised form
        ``CREATE OR REPLACE [modifiers] VIEW <schema>.<name>``. A name that
        already carries a qualifier (``x.y``) is left untouched so it is not
        qualified twice, and only the first header in the text is rewritten
        so look-alike text inside the view body is never modified.
    """
    header = re.compile(
        # Keyword sequence, tolerating any whitespace and optional modifiers.
        r'CREATE\s+OR\s+REPLACE\s+((?:(?:SECURE|RECURSIVE|MATERIALIZED)\s+)*)VIEW\s+'
        # Skip names that are already qualified: <identifier> followed by a dot.
        r'(?!(?:"[^"]*"|\w+)\s*\.)'
        # The name itself (quoted or bare) must start right here.
        r'(?="|\w)',
        flags=re.IGNORECASE,
    )

    def _qualify(match):
        # A callable replacement keeps *schema* literal (no backslash or
        # group-reference interpretation as in a replacement template).
        modifiers = ' '.join(match.group(1).upper().split())
        keyword = f'{modifiers} VIEW' if modifiers else 'VIEW'
        return f'CREATE OR REPLACE {keyword} {schema}.'

    return header.sub(_qualify, ddl, count=1).strip()

def fix_procedure_ddl(ddl, schema):
    """Fully qualify the procedure name in a ``CREATE OR REPLACE PROCEDURE`` header.

    The name is prefixed with ``<database>.<schema>.`` where the database is
    taken from ``CONFIG['dev']['database']``. In addition, any explicit
    precision on ``TIMESTAMP_TZ(n)`` is dropped throughout the DDL text.

    Args:
        ddl: DDL text of a single procedure.
        schema: Schema name the procedure belongs to.

    Returns:
        The rewritten DDL with leading/trailing whitespace stripped.
    """
    db_name = CONFIG['dev']['database']
    full_schema = f"{db_name}.{schema}"

    pattern = r'(CREATE OR REPLACE PROCEDURE\s+)(["\']?[^("\']+["\']?)(\s*\()'

    def _qualify(match):
        keyword, name, paren = match.groups()
        # An unquoted name containing a dot is already qualified; prefixing it
        # again would produce an invalid four-part name.
        if name[0] not in '"\'' and '.' in name:
            return match.group(0)
        # Building the text here (instead of a '\1...' template) keeps a
        # database name that starts with a digit from being read as part of
        # the group number.
        return f'{keyword}{full_schema}.{name}{paren}'

    # count=1: only the header is rewritten, never matching text that happens
    # to appear inside the procedure body.
    fixed_ddl = re.sub(pattern, _qualify, ddl, count=1, flags=re.IGNORECASE)

    # Remove the precision from TIMESTAMP_TZ(n); other types are left as-is.
    fixed_ddl = re.sub(r'TIMESTAMP_TZ\(\s*\d+\s*\)', 'TIMESTAMP_TZ', fixed_ddl, flags=re.IGNORECASE)

    return fixed_ddl.strip()


def extract_views(cursor):
    """Fetch and fix the DDL of every view in the configured schemas.

    Views are listed from ``INFORMATION_SCHEMA.VIEWS`` for the schemas in
    ``CONFIG['schemas']`` within the database ``CONFIG['dev']['database']``;
    each view's DDL is then read with ``GET_DDL`` and passed through
    ``fix_view_ddl``.

    Args:
        cursor: Open database cursor; it is reused for the per-view queries.

    Returns:
        List of ``(schema, view_name, fixed_ddl)`` tuples ordered by schema
        and view name. Empty when no schemas are configured.
    """
    schemas = list(CONFIG['schemas'])
    if not schemas:
        # "IN ()" is not valid SQL, and there is nothing to extract anyway.
        return []

    # One bind placeholder per configured schema, so the query keeps working
    # when the schema list grows or shrinks. Only "%s" markers are
    # interpolated into the SQL text; the values themselves stay bound.
    placeholders = ', '.join(['%s'] * len(schemas))
    cursor.execute(f"""
        SELECT TABLE_SCHEMA, TABLE_NAME
        FROM INFORMATION_SCHEMA.VIEWS
        WHERE TABLE_SCHEMA IN ({placeholders})
          AND TABLE_CATALOG = %s
        ORDER BY TABLE_SCHEMA, TABLE_NAME
    """, schemas + [CONFIG['dev']['database']])

    objects = []
    # fetchall() materialises the listing before the cursor is reused below.
    for schema, name in cursor.fetchall():
        cursor.execute("SELECT GET_DDL('VIEW', %s || '.' || %s)", (schema, name))
        ddl = cursor.fetchone()[0]
        objects.append((schema, name, fix_view_ddl(ddl, schema)))
    return objects

def extract_procedures(cursor):
    """Fetch and fix the DDL of every procedure in the configured schemas.

    Procedures are listed from ``INFORMATION_SCHEMA.PROCEDURES`` for the
    schemas in ``CONFIG['schemas']`` within ``CONFIG['dev']['database']``.
    For each one the argument signature is reduced to its type list, the DDL
    is read with ``GET_DDL`` and passed through ``fix_procedure_ddl``.

    A procedure whose DDL cannot be fetched or fixed is reported on stdout
    and skipped; the remaining procedures are still processed.

    Args:
        cursor: Open database cursor; it is reused for the per-procedure queries.

    Returns:
        List of ``(schema, procedure_name, fixed_ddl)`` tuples. Empty when no
        schemas are configured.
    """
    schemas = list(CONFIG['schemas'])
    if not schemas:
        # "IN ()" is not valid SQL, and there is nothing to extract anyway.
        return []

    # One bind placeholder per configured schema instead of a fixed three.
    placeholders = ', '.join(['%s'] * len(schemas))
    cursor.execute(f"""
        SELECT PROCEDURE_SCHEMA, PROCEDURE_NAME, ARGUMENT_SIGNATURE
        FROM INFORMATION_SCHEMA.PROCEDURES
        WHERE PROCEDURE_SCHEMA IN ({placeholders})
          AND PROCEDURE_CATALOG = %s
        ORDER BY PROCEDURE_SCHEMA, PROCEDURE_NAME
    """, schemas + [CONFIG['dev']['database']])

    # Materialise the listing first: the cursor is reused inside the loop.
    rows = cursor.fetchall()
    print(f"🔍 Extracting {len(rows)} procedures...")

    objects = []
    for schema, proc_name, arg_signature in rows:
        try:
            # Remove parameter names: "(param1 TYPE1, param2 TYPE2)" -> "TYPE1, TYPE2".
            # An empty signature "()" reduces to "" and yields "schema.name()".
            types_only = re.sub(r'\b\w+\s+', '', arg_signature).strip('() ')
            full_name = f"{schema}.{proc_name}({types_only})"

            cursor.execute("SELECT GET_DDL('PROCEDURE', %s)", (full_name,))
            ddl = cursor.fetchone()[0]
            objects.append((schema, proc_name, fix_procedure_ddl(ddl, schema)))

        except Exception as e:
            # Deliberate best-effort: one failing procedure must not abort the run.
            print(f"   ❌ SKIP {schema}.{proc_name}: {e}")
            continue

    return objects

def generate_sql_files(objects, object_type):
    """Write the PROD deployment script for the extracted objects.

    The script is assembled against the DEV database name and then every
    occurrence of that name is replaced with the PROD database name. The
    result is written to ``<output_dir>/<object_type>_to_deploy_prod.sql``
    (UTF-8), where ``output_dir`` comes from ``CONFIG``.

    Args:
        objects: Iterable of ``(schema, name, ddl)`` tuples; must support ``len``.
        object_type: Label used in the header comment, the file name and the
            progress message (e.g. ``'views'``).
    """
    # Read the DEV name from CONFIG so the substitution below cannot drift
    # away from the name used in the USE DATABASE statement.
    dev_db = CONFIG['dev']['database']
    prod_db = 'SNOWFLAKE_LIQUIBASE_PROD_USER_01'

    lines = [
        f"-- {len(objects)} {object_type} from DEV with schema prefixes",
        f"USE DATABASE {dev_db};",
        "",
    ]
    for schema, name, ddl in objects:
        lines.extend([f"-- {schema}.{name}", ddl + ";", ""])
    dev_sql = '\n'.join(lines)

    # parents=True: create missing intermediate directories as well.
    Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)

    # PROD file (only database name changed). This is a plain substring
    # replacement, so any longer identifier containing the DEV database name
    # is rewritten too.
    prod_sql = dev_sql.replace(dev_db, prod_db)
    prod_path = os.path.join(CONFIG['output_dir'], f'{object_type}_to_deploy_prod.sql')
    with open(prod_path, 'w', encoding='utf-8') as f:
        f.write(prod_sql)

    # object_type is already plural ('views' / 'procedures'); no extra "s".
    print(f"✅ Generated {len(objects)} {object_type} WITH SCHEMA PREFIXES:")
    print(f"   {prod_path}")

def main(object_type):
    """Extract one kind of object from Snowflake DEV and write its SQL file.

    Args:
        object_type: ``'views'`` or ``'procedures'`` (case-insensitive).

    Returns:
        ``0`` on success, ``1`` when any step failed. Errors are reported on
        stdout rather than raised, and the connection is always closed.
    """
    extractors = {
        'views': extract_views,
        'procedures': extract_procedures,
    }
    conn = None
    try:
        # Validate before connecting so a bad argument never opens a session.
        extractor = extractors.get(object_type.lower())
        if extractor is None:
            raise ValueError("object_type must be 'views' or 'procedures'")

        print("🔌 Connecting to Snowflake DEV...")
        conn = connect_snowflake()
        cursor = conn.cursor()

        print(f"🔍 Extracting & fixing {object_type} definitions...")
        objects = extractor(cursor)

        print(f"📊 Fixed {len(objects)} {object_type} (schema prefixes added)")
        generate_sql_files(objects, object_type)
        return 0

    except Exception as e:
        # Top-level boundary: report the failure and signal it via the status.
        print(f"❌ Error: {e}")
        return 1
    finally:
        if conn:
            conn.close()
            print("🔌 Connection closed")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract Snowflake Views or Procedures")
    parser.add_argument("object_type", choices=['views', 'procedures'], 
                       help="Extract views or procedures")
    args = parser.parse_args()
    
    # Propagate failures to the shell / CI as a non-zero exit code.
    sys.exit(main(args.object_type))
