#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Recursively copy or sync files from source to destination, skipping rejected patterns.
Supports local filesystem, HTTP(S), and FTP(S) sources.
"""

import argparse
import os
import posixpath
import re
import shutil
import sys
import urllib.parse
from datetime import datetime
from email.utils import parsedate_to_datetime
import ftplib

# 外部ライブラリが必要です: pip install requests beautifulsoup4
try:
    import requests
    from bs4 import BeautifulSoup
except ImportError:
    print("Error: 'requests' and 'beautifulsoup4' libraries are required.", file=sys.stderr)
    print("Please install them using: pip install requests beautifulsoup4", file=sys.stderr)
    sys.exit(1)


# --- Global settings ---
# Defaults for -s/--source and -d/--dest (used by parse_args).
COE_tkProg_source = r"\\192.168.27.2\share\apps\tkProg\tklib"
COE_tkProg_dest   = "tklib"

# When True: sync_files prints each walked root and its dirs, and
# walk_http / walk_ftp terminate the program on an access error instead of
# skipping that directory.
debug = True


# Regular-expression strings for paths to reject. main() compiles them with
# re.IGNORECASE and should_reject() searches them (re.search, so unanchored
# unless the pattern uses ^/$) against relative paths that use "/" as the
# separator.
reject_patterns = [
    r"__pycache__$",
    r"\.pyc$",
    r"\.prev$",
    r"\.junk$",
    r" - コピー",
    r"kamiya",
    r"personal",
    r"^QE\/",
    r"^QE$",
#    r"^VASP\/",
#    r"^VASP$",
]


# --- ユーティリティ（URL正規化/相対計算）---

def is_url(s: str) -> bool:
    """Tell whether *s* starts with an http://, https://, ftp:// or ftps:// prefix.

    The comparison ignores case; anything else (local or UNC paths,
    other schemes) is not treated as a URL.
    """
    lowered = s.lower()
    return any(lowered.startswith(prefix) for prefix in ("http://", "https://", "ftp://", "ftps://"))

def normalize_url(url: str) -> str:
    """
    Normalize an http(s)/ftp(s) URL and return it.

    - Backslashes (\\) are converted to forward slashes (/).
    - Runs of '/' in the path are collapsed to a single '/'.
    - A trailing '/' is appended, because the URL is assumed to name a directory.

    Anything that is not an http(s)/ftp(s) URL is returned unchanged.

    A literal '//' directly after a word character in the given URL (for
    example ``http://host//dir``) is rejected: an error is printed to stderr
    and the program exits with status 1.
    """
    if not is_url(url):
        return url

    if re.search(r'\w//', url):
        print(f"\nError in normalize_url(): Invalid separator '//' in [{url}]\n", file=sys.stderr)
        # The site builtin exit() with no argument would report success (status 0).
        sys.exit(1)

    url = url.replace("\\", "/")
    parts = urllib.parse.urlsplit(url)
    # Collapse duplicate '/' (e.g. produced from '\\' pairs); the leading '/' is kept.
    path = re.sub(r"/+", "/", parts.path or "/")
    if not path.endswith("/"):
        path += "/"
    return urllib.parse.urlunsplit((parts.scheme, parts.netloc, path, parts.query, parts.fragment))

def relpath_url(path: str, base: str) -> str:
    """
    Return *path* relative to *base* using POSIX rules (always '/'-separated).

    Empty/None inputs are treated as '/', backslashes become '/', and a
    leading '/' is added when missing, so full URLs can be passed too.
    os.path.relpath is avoided because on Windows it may return '\\'.
    """
    def _rooted(p):
        cleaned = (p or "/").replace("\\", "/")
        return cleaned if cleaned.startswith("/") else "/" + cleaned

    return posixpath.relpath(_rooted(path), _rooted(base))


# --- 基本関数 ---
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options: -s/--source, -d/--dest (defaults from COE_tkProg_source /
    COE_tkProg_dest), -m/--maxlevel (-1 = unlimited), --dry-run, --mirror.
    """
    description = (
        "Recursively copy or sync files from source to destination, skipping rejected patterns. "
        "Supports local, http(s), and ftp(s) sources."
    )
    parser = argparse.ArgumentParser(description=description)

    option_specs = (
        (("-s", "--source"), dict(type=str, default=COE_tkProg_source, help="Source directory or URL")),
        (("-d", "--dest"), dict(type=str, default=COE_tkProg_dest, help="Destination directory")),
        (("-m", "--maxlevel"), dict(type=int, default=-1, help="Maximum recursion depth (-1 for infinite)")),
        (("--dry-run",), dict(action="store_true",
                              help="List orphan files that would be deleted by --mirror, without deleting.")),
        (("--mirror",), dict(action="store_true",
                             help="Enable mirror mode: delete orphan files in destination.")),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()

def handle_user_choice(prompt):
    """Ask the user how to handle an item and return a one-letter choice.

    Prints *prompt* and the option legend, then reads answers until one of
    y/n/a/r/s is entered (case-insensitive; surrounding whitespace ignored).

    Returns:
        'y' (yes), 'n' (no), 'a' (yes to all), 'r' (retain all) or
        's' (stop). If stdin is exhausted (EOF, e.g. a non-interactive run),
        's' is returned instead of raising EOFError.
    """
    print(prompt)
    print("Options: [Y]es, [N]o, [A]ll, [R]etain All, [S]top")
    valid_choices = ('y', 'n', 'a', 'r', 's')
    while True:
        try:
            choice = input("Your choice (y/n/a/r/s): ").strip().lower()
        except EOFError:
            # No more input can arrive: stop cleanly instead of crashing with a traceback.
            print()
            return 's'
        if choice in valid_choices:
            return choice
        print("Invalid input, please enter 'y', 'n', 'a', 'r', or 's'.")

def should_reject(path, compiled_patterns):
    """Return True if any compiled pattern is found (re.search) in *path*.

    Backslashes in *path* are turned into '/' first, so patterns can always
    be written with '/' as the separator.
    """
    candidate = "/".join(path.split("\\"))
    for pattern in compiled_patterns:
        if pattern.search(candidate):
            return True
    return False


# --- ディレクトリ走査 (Walker) 関数群 ---
def walk_local(source, compiled_patterns, maxlevel=-1):
    """Walk a local directory tree top-down and yield (root, dirs, file_info_list).

    Subdirectories whose path relative to *source* matches a reject pattern
    are pruned in place, so os.walk never descends into them. Directories
    deeper than *maxlevel* (-1 = unlimited) are neither yielded nor descended.

    Each file_info_list entry is a dict with keys:
        'name':  the file's base name,
        'mtime': os.path.getmtime of the file, or 0 when it cannot be read
                 (e.g. a broken symlink or a file removed mid-walk),
        'dir':   the containing root.
    """
    source_root_depth = source.rstrip(os.sep).count(os.sep)
    for root, dirs, files in os.walk(source, topdown=True):
        current_depth = root.rstrip(os.sep).count(os.sep) - source_root_depth
        if maxlevel != -1 and current_depth > maxlevel:
            # Too deep: skip this root and stop os.walk from going further down.
            del dirs[:]
            continue

        # Prune rejected directories (judged by path relative to source).
        # Slice assignment is required so os.walk sees the filtered list.
        dirs[:] = [d for d in dirs
                   if not should_reject(os.path.relpath(os.path.join(root, d), source), compiled_patterns)]

        file_info_list = []
        for f in files:
            try:
                mtime = os.path.getmtime(os.path.join(root, f))
            except OSError:
                # Keep the entry with an unknown (0) mtime rather than
                # letting one unreadable file abort the whole walk.
                mtime = 0
            file_info_list.append({'name': f, 'mtime': mtime, 'dir': root})
        yield root, dirs, file_info_list

def walk_http(current_url, source_base_path, compiled_patterns, maxlevel=-1, current_level=0, timeout=30):
    """Walk an HTTP directory listing (Apache-style index pages) recursively.

    Yields (current_url, dirs, files) for each listing page, parent before
    children:
      - dirs:  names (still URL-encoded, no trailing '/') of links ending in
               '/' that do not match a reject pattern;
      - files: dicts {'name': decoded file name,
                      'mtime': Last-Modified from a HEAD request as a POSIX
                               timestamp, 0 if unavailable,
                      'dir': current_url}.

    Only links that resolve to a direct child of *current_url* are used.
    Parent and sort links, links to other hosts or deeper paths, and names
    that could escape the destination directory (e.g. '../x') are ignored.

    Args:
        current_url: listing URL to fetch. It should end with '/' so that
            relative links resolve beneath it.
        source_base_path: not used for resolution; forwarded unchanged to
            recursive calls.
        compiled_patterns: compiled reject regexes. Directories are tested
            by their name relative to *current_url*.
        maxlevel: maximum depth (-1 for unlimited).
        current_level: depth of *current_url* (0 for the start URL).
        timeout: seconds passed to every requests call, so a stalled
            server cannot hang the walk forever.

    If the listing cannot be fetched, an error goes to stderr and this
    subtree is skipped. When the module-level ``debug`` flag is set, the
    program exits with status 1 instead.
    """
    if maxlevel != -1 and current_level > maxlevel:
        return

    try:
        response = requests.get(current_url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Error in walk_http(): Failed to access URL [{current_url}]: {e}", file=sys.stderr)
        if debug:
            sys.exit(1)
        return

    soup = BeautifulSoup(response.text, 'html.parser')

    dirs, files = [], []
    for link in soup.find_all('a'):
        href = link.get('href')
        if not href:
            continue

        # Normalize '\' to '/' in case the server emits Windows-style links.
        href = href.replace('\\', '/')

        # Query-only (sort) links, fragments, and self/parent links.
        if href.startswith(('?', '#')) or href in ('.', '..', './', '../'):
            continue

        full_url = urllib.parse.urljoin(current_url, href)
        # Path relative to the current listing, always '/'-separated.
        rel_path = relpath_url(full_url, current_url)

        # Keep direct children only. A '/' left in rel_path means the link
        # points deeper, upwards ('../'), or to another host/scheme.
        if rel_path in ('.', '..') or '/' in rel_path:
            continue

        if href.endswith('/'):
            # Directory. The name stays URL-encoded; recursion rebuilds the URL from it.
            if not should_reject(rel_path, compiled_patterns):
                name = rel_path.strip('/')
                if name:
                    dirs.append(name)
        else:
            # File. Decode the name for use on the local filesystem.
            name = urllib.parse.unquote(rel_path)
            # The decoded name (e.g. from %2F) must still be a single path component.
            if name in ('', '.', '..') or '/' in name or '\\' in name:
                continue

            mtime = 0
            try:
                head_resp = requests.head(full_url, timeout=timeout)
                last_modified = head_resp.headers.get('Last-Modified')
                if last_modified:
                    mtime = parsedate_to_datetime(last_modified).timestamp()
            except (requests.RequestException, TypeError, ValueError):
                # Best effort: unreachable or malformed header leaves mtime at 0 (unknown).
                pass
            files.append({'name': name, 'mtime': mtime, 'dir': current_url})

    yield current_url, dirs, files

    if maxlevel == -1 or current_level < maxlevel:
        for d in dirs:
            # The './' prefix keeps a name containing ':' from being parsed as a URL scheme.
            next_url = urllib.parse.urljoin(current_url, './' + d + '/')
            yield from walk_http(next_url, source_base_path, compiled_patterns, maxlevel,
                                 current_level + 1, timeout)

def walk_ftp(ftp_conn, current_path, source_base_path, compiled_patterns, maxlevel=-1, current_level=0):
    """Walk an FTP directory tree recursively over an open ftplib connection.

    Yields (current_path, dirs, files) for each directory, parent before
    children:
      - dirs:  entry names that can be entered with CWD and whose path
               relative to *source_base_path* does not match a reject pattern;
      - files: dicts {'name': entry name,
                      'mtime': MDTM time (UTC) as a POSIX timestamp, 0 if unavailable,
                      'dir': current_path}.

    While the caller handles a yielded tuple, the server's working directory
    is *current_path*. When the generator resumes, the working directory that
    was current on entry is restored before children are visited.

    Paths are handled as POSIX: backslashes become '/', and *current_path*
    is made absolute. If a directory cannot be entered, an error goes to
    stderr and it is skipped. When the module-level ``debug`` flag is set,
    the program exits with status 1 instead (same for listing failures).
    """
    from datetime import timezone  # MDTM replies are UTC (RFC 3659)

    if maxlevel != -1 and current_level > maxlevel:
        return

    current_path = (current_path or "/").replace("\\", "/")
    source_base_path = (source_base_path or "/").replace("\\", "/")
    if not current_path.startswith("/"):
        current_path = "/" + current_path

    try:
        original_cwd = ftp_conn.pwd()
        ftp_conn.cwd(current_path)
    except ftplib.error_perm as e:
        print(f"Error in walk_ftp(): Failed to access FTP path {current_path}: {e}", file=sys.stderr)
        if debug:
            sys.exit(1)
        return

    items = []
    try:
        items = ftp_conn.nlst()
    except ftplib.error_perm as e:
        print(f"Warning in walk_ftp(): Could not list items in {current_path}: {e}", file=sys.stderr)
        if debug:
            sys.exit(1)

    dirs, files = [], []
    for item_name in items:
        if item_name in ('.', '..'):
            continue

        try:
            # Directory test: only directories can be entered with CWD.
            ftp_conn.cwd(item_name)
        except ftplib.error_perm:
            # A file: try MDTM for its modification time.
            mtime = 0
            try:
                reply = ftp_conn.voidcmd(f'MDTM {item_name}')
                # Reply looks like "213 YYYYMMDDHHMMSS[.sss]"; drop optional fractional seconds.
                stamp = reply[4:].strip().split('.')[0]
                mtime = datetime.strptime(stamp, '%Y%m%d%H%M%S').replace(tzinfo=timezone.utc).timestamp()
            except (ftplib.error_perm, ftplib.error_temp, ftplib.error_reply, ValueError):
                pass
            files.append({'name': item_name, 'mtime': mtime, 'dir': current_path})
            continue

        # Return by absolute path. On some servers, CWD '..' after entering a
        # symlinked directory may land in the link target's parent instead.
        ftp_conn.cwd(current_path)

        full_item_path = f"{current_path.rstrip('/')}/{item_name}"
        rel_path = relpath_url(full_item_path, source_base_path)
        if not should_reject(rel_path, compiled_patterns):
            dirs.append(item_name)

    yield current_path, dirs, files
    ftp_conn.cwd(original_cwd)

    if maxlevel == -1 or current_level < maxlevel:
        for d in dirs:
            next_path = f"{current_path.rstrip('/')}/{d}"
            yield from walk_ftp(ftp_conn, next_path, source_base_path, compiled_patterns, maxlevel, current_level + 1)


# --- メインロジック ---
def _remove_path(path):
    """Delete *path*: the whole tree for a real directory, a plain unlink otherwise (files, symlinks)."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def _is_plain_name(name):
    """Return True if *name* is a single path component that cannot escape its parent directory."""
    if name in ('', '.', '..'):
        return False
    separators = {'/', os.sep}
    if os.altsep:
        separators.add(os.altsep)
    if any(sep in name for sep in separators):
        return False
    # On Windows a drive prefix ("C:foo") would make os.path.join drop the parent directory.
    return not os.path.splitdrive(name)[0]


def _open_walker(source, compiled_patterns, maxlevel, protocol):
    """Create the directory walker for *protocol*.

    Returns (walker, ftp_conn, base):
        walker:   generator yielding (root, dirs, files);
        ftp_conn: the open FTP connection, or None when not FTP;
        base:     what walker roots are made relative to (the source
                  path/URL, or the URL path for FTP).
    Exits the program if the FTP connection fails or the protocol is unknown.
    """
    if protocol == 'local':
        return walk_local(source, compiled_patterns, maxlevel), None, source

    if protocol.startswith('http'):
        # walk_http resolves links against absolute URLs; its base-path argument is only passed through.
        return walk_http(source, '', compiled_patterns, maxlevel), None, source

    if protocol.startswith('ftp'):
        parsed_url = urllib.parse.urlparse(source)
        # Walk the directory named in the URL, not the server root.
        base_path = parsed_url.path or '/'
        use_tls = (protocol == 'ftps')
        ftp_conn = None
        try:
            ftp_conn = (ftplib.FTP_TLS if use_tls else ftplib.FTP)()
            # ftplib.FTP_TLS implements explicit FTPS (AUTH TLS on the normal
            # control port). Implicit FTPS on port 990 is not supported by ftplib.
            ftp_conn.connect(parsed_url.hostname, parsed_url.port or 21)
            # Credentials inside a URL are percent-encoded (e.g. '@' as %40).
            user = urllib.parse.unquote(parsed_url.username) if parsed_url.username else "anonymous"
            password = urllib.parse.unquote(parsed_url.password) if parsed_url.password else ""
            ftp_conn.login(user, password)
            if use_tls:
                ftp_conn.prot_p()
        except Exception as e:
            print(f"FTP connection failed: {e}", file=sys.stderr)
            if ftp_conn:
                ftp_conn.close()
            sys.exit(1)
        walker = walk_ftp(ftp_conn, base_path, base_path, compiled_patterns, maxlevel)
        return walker, ftp_conn, base_path

    print("Could not determine a valid walker for the source.", file=sys.stderr)
    sys.exit(1)


def _purge_rejected_items(dest_dir, dest, compiled_patterns, global_user_choice):
    """Offer to delete entries of *dest_dir* that match a reject pattern.

    Entries are matched by their '/'-separated path relative to *dest*.
    *global_user_choice* is the sticky answer so far: None, 'a' (delete all)
    or 'r' (retain all). The possibly updated value is returned. Choosing
    's' stops the program with status 0.
    """
    if not os.path.isdir(dest_dir):
        return global_user_choice

    for item in os.listdir(dest_dir):
        dest_item_path = os.path.join(dest_dir, item)
        rel_item_path = os.path.relpath(dest_item_path, dest).replace(os.sep, '/')
        if not should_reject(rel_item_path, compiled_patterns):
            continue

        if global_user_choice == 'a':
            _remove_path(dest_item_path)
            print(f"Deleted '{dest_item_path}' (All mode).")
            continue
        if global_user_choice == 'r':
            print(f"Kept '{dest_item_path}' (Retain All mode).")
            continue

        choice = handle_user_choice(f"'{dest_item_path}' matches a reject pattern. Delete it?")
        if choice in ('y', 'a'):
            _remove_path(dest_item_path)
            print(f"Deleted '{dest_item_path}'.")
            if choice == 'a':
                global_user_choice = 'a'
        elif choice in ('n', 'r'):
            print(f"Kept '{dest_item_path}'.")
            if choice == 'r':
                global_user_choice = 'r'
        elif choice == 's':
            print("Stopping program.")
            sys.exit(0)
    return global_user_choice


def _fetch_file(protocol, root, name, dest_path, src_mtime, ftp_conn):
    """Copy one source file (*name* inside *root*) to *dest_path*.

    Local files are copied with shutil.copy2 (content and metadata). Remote
    files are streamed into *dest_path*. If a remote transfer fails, the
    partial file is removed and the exception is re-raised; otherwise a
    truncated file with a fresh mtime would later be skipped as "newer".
    After a successful remote transfer, atime/mtime are set to *src_mtime*
    when it is known (non-zero).
    """
    if protocol == 'local':
        shutil.copy2(os.path.join(root, name), dest_path)
        return

    out = open(dest_path, 'wb')
    try:
        with out:
            if protocol.startswith('http'):
                # Names were URL-decoded for local use; re-encode them for the request URL.
                src_url = urllib.parse.urljoin(root, urllib.parse.quote(name))
                with requests.get(src_url, stream=True, timeout=60) as r:
                    r.raise_for_status()
                    # iter_content() undoes Content-Encoding (gzip/deflate); r.raw would not.
                    for chunk in r.iter_content(chunk_size=65536):
                        out.write(chunk)
            else:  # ftp / ftps
                remote_path = f"{root.rstrip('/')}/{name}"
                ftp_conn.retrbinary(f'RETR {remote_path}', out.write)
    except BaseException:
        # The with block has closed the file, so it can be removed (also on Windows).
        try:
            os.remove(dest_path)
        except OSError:
            pass
        raise

    if src_mtime:
        os.utime(dest_path, (src_mtime, src_mtime))


def _mirror_orphans(dest, source_rel_paths, maxlevel, delete):
    """Report (or, when *delete* is true, remove) destination items absent from the source.

    A file is an orphan when its '/'-separated path relative to *dest* is
    not in *source_rel_paths*. A directory is considered only when it is
    empty. Levels deeper than *maxlevel* (-1 = unlimited) are left untouched.
    """
    phase_name = 'Mirror' if delete else 'Mirror Dry Run'
    print(f"\n[{phase_name}] Checking for orphan files in destination (respecting maxlevel={maxlevel})...")

    deleted_count = 0
    found_count = 0
    dest_root_depth = dest.rstrip(os.sep).count(os.sep)

    # Bottom-up walk: children are handled before their parent, so a
    # directory emptied by deletions can itself be removed afterwards.
    for root, dirs, files in os.walk(dest, topdown=False):
        current_depth = root.rstrip(os.sep).count(os.sep) - dest_root_depth
        if maxlevel != -1 and current_depth > maxlevel:
            continue

        for name in files:
            dest_path = os.path.join(root, name)
            rel_path = os.path.relpath(dest_path, dest).replace(os.sep, '/')
            if rel_path in source_rel_paths:
                continue
            if delete:
                print(f"Deleting orphan file: {rel_path}")
                try:
                    os.remove(dest_path)
                    deleted_count += 1
                except OSError as e:
                    print(f"  -> FAILED to delete: {e}", file=sys.stderr)
            else:
                print(f"Found orphan file: {rel_path}")
                found_count += 1

        # Sub-directories of root sit at depth current_depth + 1.
        if maxlevel != -1 and current_depth >= maxlevel:
            continue
        for name in dirs:
            dest_path = os.path.join(root, name)
            rel_path = os.path.relpath(dest_path, dest).replace(os.sep, '/')
            if rel_path in source_rel_paths or os.listdir(dest_path):
                continue
            if delete:
                print(f"Deleting orphan directory: {rel_path}")
                try:
                    os.rmdir(dest_path)
                    deleted_count += 1
                except OSError as e:
                    print(f"  -> FAILED to delete: {e}", file=sys.stderr)
            else:
                print(f"Found orphan directory: {rel_path}")
                found_count += 1

    if delete:
        print(f"[{phase_name} Complete] Deleted {deleted_count} orphan items.")
    else:
        print(f"[{phase_name} Complete] Found {found_count} orphan items.")


def sync_files(source, dest, compiled_patterns, maxlevel, protocol, dry_run, mirror):
    """Recursively sync *source* into *dest*, then optionally mirror.

    Args:
        source: local directory or (normalized) http(s)/ftp(s) URL.
        dest: destination directory; sub-directories are created as needed.
        compiled_patterns: compiled reject regexes, matched against
            '/'-separated paths relative to source/dest.
        maxlevel: maximum recursion depth, -1 for unlimited.
        protocol: 'local', 'http', 'ftp' or 'ftps'.
        dry_run: only list orphan destination items; takes precedence over
            *mirror*, so nothing is deleted when both are set.
        mirror: delete orphan destination items (unless *dry_run*).

    A file is copied when it is missing in dest, when its source mtime is
    unknown (0), or when the dest copy is older than the source. Items in
    dest that match a reject pattern are offered for interactive deletion.
    """
    global_user_choice = None  # sticky 'a' / 'r' answer from the reject-deletion prompt
    # '/'-separated paths, relative to the source root, of every dir/file
    # seen in the source. The mirror phase treats anything else in dest as an orphan.
    source_rel_paths = set()

    print(f"\n[Sync Phase] Scanning source and syncing to [{dest}]...")

    walker, ftp_conn, base = _open_walker(source, compiled_patterns, maxlevel, protocol)
    try:
        for root, dirs, files in walker:
            if debug:
                print("root:", root)
                print("dirs:", dirs)

            if protocol == 'local':
                rel_dir = os.path.relpath(root, source)
            else:
                # URL / FTP paths: compute with POSIX rules so no '\' sneaks in.
                rel_dir = relpath_url(root, base)

            dest_dir = dest if rel_dir == "." else os.path.join(dest, rel_dir)
            if not os.path.exists(dest_dir):
                print(f"** Creating dir [{dest_dir}]")
            os.makedirs(dest_dir, exist_ok=True)

            # 1. Offer to delete dest items that match a reject pattern.
            global_user_choice = _purge_rejected_items(dest_dir, dest, compiled_patterns, global_user_choice)

            # 2. Record source sub-directories by their full relative path.
            #    Bare names would make --mirror treat nested items as orphans.
            current_rel_dir = "" if rel_dir == "." else rel_dir.replace(os.sep, '/')
            source_rel_paths.update(posixpath.join(current_rel_dir, d) for d in dirs)

            # 3. Copy new or updated files.
            print()
            print(f"Copy path candidates in [{root}]: ", )
            for file_info in files:
                name = file_info['name']
                if not _is_plain_name(name):
                    # Names from a remote listing must not escape dest_dir.
                    print(f"Skipping unsafe file name [{name}] in [{root}]", file=sys.stderr)
                    continue

                file_rel_path = posixpath.join(current_rel_dir, name)
                if should_reject(file_rel_path, compiled_patterns):
                    continue
                source_rel_paths.add(file_rel_path)

                dest_path = os.path.join(dest_dir, name)
                src_mtime = file_info['mtime']
                if src_mtime and os.path.exists(dest_path) and os.path.getmtime(dest_path) >= src_mtime:
                    print(f"Target file '{file_rel_path}' is newer. Skip")
                    continue

                print(f"Syncing [{file_rel_path}] to [{dest_path}]...", end='')
                try:
                    _fetch_file(protocol, root, name, dest_path, src_mtime, ftp_conn)
                    print("  -> Synced.")
                except Exception as e:
                    print(f"  -> FAILED to sync: {e}", file=sys.stderr)
    finally:
        # Also runs on sys.exit (e.g. the user chose "stop") or an unexpected error.
        if ftp_conn:
            try:
                ftp_conn.quit()
            except Exception:
                # quit() needs a live connection; close() just drops the socket.
                ftp_conn.close()

    print(f"\n[Sync Phase Complete] Found {len(source_rel_paths)} valid files/dirs in source.")

    # 4. Mirroring: --dry-run only reports, even when --mirror is also given.
    if dry_run or mirror:
        _mirror_orphans(dest, source_rel_paths, maxlevel, delete=mirror and not dry_run)


def main():
    """Script entry point.

    Parses the command line and normalizes a URL source. Works out the
    protocol ('local', 'http', 'ftp' or 'ftps') and checks that a local
    source directory exists (exit status 1 otherwise). Prints a settings
    summary, then runs sync_files with the reject patterns compiled
    case-insensitively.
    """
    args = parse_args()

    # URL sources are normalized (\ -> /, duplicate / collapsed, trailing / added).
    source = args.source
    if is_url(source):
        source = normalize_url(source)
    dest = os.path.abspath(args.dest)

    if is_url(source):
        lowered = source.lower()
        if lowered.startswith(('https://', 'http://')):
            protocol = 'http'
        elif lowered.startswith('ftps://'):
            protocol = 'ftps'
        elif lowered.startswith('ftp://'):
            protocol = 'ftp'
        else:
            protocol = 'local'
    else:
        protocol = 'local'
        if not os.path.isdir(source):
            print(f"Error: Source directory '{source}' not found or is not a directory.", file=sys.stderr)
            sys.exit(1)

    def on_off(flag):
        return 'Enabled' if flag else 'Disabled'

    depth_text = 'Infinite' if args.maxlevel == -1 else args.maxlevel
    print(f"\nStarting file sync from '{source}' to '{dest}'")
    print(f"  - Protocol: {protocol}")
    print(f"  - Max Recursion Depth: {depth_text}")
    print(f"  - Dry run: {on_off(args.dry_run)}")
    print(f"  - Mirror: {on_off(args.mirror)}")
    print(f"Source     : [{source}]")
    print(f"Destination: [{dest}]")

    compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in reject_patterns]
    sync_files(source, dest, compiled_patterns, args.maxlevel, protocol, args.dry_run, args.mirror)
    print("\nSync process finished.\n")


if __name__ == "__main__":
    # Run the sync only when executed as a script, not when imported.
    main()
