from pybit.unified_trading import HTTP
from datetime import datetime, timezone
import csv
import os
from enum import Enum

class KlineInterval(Enum):
    """Candle (kline) intervals for the Bybit kline request.

    Each member's value is the string passed as the request's ``interval``
    argument: a number of minutes for intraday intervals, ``"D"`` for one day
    and ``"W"`` for one week. ``to_milliseconds()`` gives the length of a
    single candle of that interval.
    """
    M1 = "1"
    M3 = "3"
    M5 = "5"
    M15 = "15"
    M30 = "30"
    H1 = "60"
    H2 = "120"
    H4 = "240"
    H6 = "360"
    H12 = "720"
    D1 = "D"
    W1 = "W"

    def __str__(self):
        # The API expects the raw value string, not "KlineInterval.M5".
        return f"{self.value}"

    def to_milliseconds(self):
        """Return the duration of one candle of this interval in milliseconds."""
        if self is KlineInterval.D1:
            minutes = 24 * 60
        elif self is KlineInterval.W1:
            minutes = 7 * 24 * 60
        else:
            # Intraday members store their length in minutes as their value.
            minutes = int(self.value)
        return minutes * 60 * 1000

def utc_datetime_to_epoch_ms(datetime_str, time_format="%Y.%m.%d_%H.%M.%S"):
    """Parse a date/time string as UTC and return Unix epoch milliseconds.

    Args:
        datetime_str: Date/time text, interpreted as UTC.
        time_format: ``strptime`` format for ``datetime_str``. The default
            matches strings such as ``"2024.08.01_00.00.00"``.

    Returns:
        Milliseconds since 1970-01-01T00:00:00Z as an ``int``. Sub-second
        input (e.g. a format containing ``%f``) is kept to millisecond
        precision instead of being truncated to whole seconds.
    """
    dt_utc = datetime.strptime(datetime_str, time_format).replace(tzinfo=timezone.utc)
    # Exact integer timedelta arithmetic: avoids both dropping the milliseconds
    # (int(timestamp()) truncates to seconds) and float rounding error.
    delta = dt_utc - datetime(1970, 1, 1, tzinfo=timezone.utc)
    return (delta.days * 86_400 + delta.seconds) * 1000 + delta.microseconds // 1000

def cy_get_kline(category, symbol, interval: "KlineInterval", start_epoch_ms, end_epoch_ms,
                 limit=1000, client=None):
    """Fetch one page of klines and return the raw candle rows.

    Args:
        category: Product category sent as ``category`` (e.g. ``"linear"``).
        symbol: Trading symbol sent as ``symbol`` (e.g. ``"BTCUSDT"``).
        interval: Candle interval; sent as ``str(interval)``.
        start_epoch_ms: Lower time bound in epoch milliseconds (``start``).
        end_epoch_ms: Upper time bound in epoch milliseconds (``end``).
        limit: Value sent as the request's ``limit`` (default 1000).
        client: Object exposing ``get_kline(**kwargs)``. Defaults to the
            module-level ``session`` that ``main()`` creates.

    Returns:
        ``response["result"]["list"]``, the list of candle rows.

    Raises:
        RuntimeError: If the response's ``retCode`` is non-zero. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    if client is None:
        client = session  # module-level HTTP session set up by main()

    response = client.get_kline(
        category=category,
        symbol=symbol,
        interval=str(interval),
        start=start_epoch_ms,
        end=end_epoch_ms,
        limit=limit,
    )

    if response["retCode"] != 0:
        raise RuntimeError(f"API Error: {response['retMsg']}")

    return response["result"]["list"]

def save_to_csv(file_name, data, write_header=False):
    """Append candle rows to a CSV file.

    Args:
        file_name: Path of the CSV file; it is opened in append mode, so
            existing content is kept.
        data: Iterable of rows, each written as one CSV record.
        write_header: When true, the column-name row is written before
            ``data``.
    """
    column_names = ["timestamp", "open", "high", "low", "close", "volume", "turnover"]
    with open(file_name, mode='a', newline='') as out_file:
        rows = [column_names, *data] if write_header else data
        csv.writer(out_file).writerows(rows)

def generate_csv_filename_bybit(category, symbol, interval: KlineInterval, start_str, end_str):
    """Build the CSV file name for a Bybit kline download.

    The result has the form
    ``bybit_<category>_<symbol>_<interval member name>_<start>_<end>.csv``,
    e.g. ``bybit_linear_BTCUSDT_M5_2024.08.01_00.00.00_2024.08.21_00.00.00.csv``.
    """
    template = 'bybit_{}_{}_{}_{}_{}.csv'
    return template.format(category, symbol, interval.name, start_str, end_str)

def delete_existing_file(file_name):
    """Delete ``file_name`` if it exists and print a notice when it was removed.

    A missing file is not an error. Removal is attempted directly (EAFP)
    rather than checked with ``os.path.exists`` first. This avoids a race in
    which the file disappears between the check and ``os.remove``, which
    would otherwise raise.
    """
    try:
        os.remove(file_name)
    except FileNotFoundError:
        return
    print(f"{file_name} 파일이 삭제되었습니다.")

def _report_candle_totals(acquired, expected):
    """Print how many candles were fetched versus how many the time range should hold."""
    print(f"Total number of acquired candles = {acquired}")
    print(f"The number of candle must be {expected}")
    print(f"The missing number of candle = {expected - acquired}")


def main():
    """Download Bybit klines for a fixed symbol and UTC time range into a CSV file.

    Creates the module-level ``session`` used by ``cy_get_kline``, deletes any
    previous CSV of the same name, and then pages backwards from the end of
    the range. Each page's rows are appended to the CSV. The candle count
    summary is printed however the loop ends (range exhausted, empty page,
    or API error).
    """
    global session
    session = HTTP(testnet=False)

    # Symbol and candle interval to request
    category = "linear"
    symbol = "BTCUSDT"
    interval = KlineInterval.M5

    # Requested time range (UTC)
    utc_datetime_start_str = "2024.08.01_00.00.00"
    utc_datetime_end_str = "2024.08.21_00.00.00"

    start_ms = utc_datetime_to_epoch_ms(utc_datetime_start_str)
    end_ms = utc_datetime_to_epoch_ms(utc_datetime_end_str)

    interval_ms = interval.to_milliseconds()
    # Both range endpoints are counted, hence the +1.
    total_number_of_candle_mustbe = 1 + (end_ms - start_ms) // interval_ms

    # Build the output file name and start from an empty file
    csv_file = generate_csv_filename_bybit(category, symbol, interval, utc_datetime_start_str, utc_datetime_end_str)
    delete_existing_file(csv_file)

    first_write = True
    total_number_of_candle_acquired = 0

    while True:
        try:
            candle_list = cy_get_kline(category, symbol, interval, start_ms, end_ms)
        except Exception as e:
            print(f"Stop Processing. {e}")
            break

        if not candle_list:
            print("Complete Candle Acquisition. the number of candle 0")
            break

        total_number_of_candle_acquired += len(candle_list)

        print(f"First candle: {candle_list[0]}")
        print(f"Last candle: {candle_list[-1]}")

        save_to_csv(csv_file, candle_list, first_write)
        first_write = False

        # The last row is treated as the oldest candle (rows are expected
        # newest-first); the next page must end one interval before it.
        oldest_candle_epoch_ms = int(candle_list[-1][0])
        next_end_ms = oldest_candle_epoch_ms - interval_ms
        if next_end_ms >= end_ms:
            # No backward progress: stop instead of requesting the same page forever.
            print("Stop Processing. the API did not return older candles")
            break
        end_ms = next_end_ms

        if end_ms < start_ms:
            print("Completed Candle Acquisition")
            break

    _report_candle_totals(total_number_of_candle_acquired, total_number_of_candle_mustbe)

# Run the download only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()