import threading
import time
import json
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from selenium.common.exceptions import TimeoutException, WebDriverException
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode

# Aggregate state shared by every worker thread; update it only while holding `lock`.
# `results` holds per-request durations in seconds (float('inf') is used as a
# failure marker); the two counters tally request outcomes.
results: list = []
successful_requests: int = 0
failed_requests: int = 0
lock = threading.Lock()  # serializes access to the three names above


# Substrings of page titles shown by Cloudflare interstitial / challenge pages.
_CLOUDFLARE_TITLE_MARKERS = (
    "Just a moment",
    "Checking your browser",
    "Verifying you are human",
    "DDoS protection by Cloudflare",
)

# Async script for execute_async_script. Selenium passes (url, method, body,
# content_type) followed by its completion callback, which receives
# {status, text, error}; status is -1 when the fetch or the body read failed.
_FETCH_SCRIPT = """
const callback = arguments[arguments.length - 1];
fetch(arguments[0], {
    method: arguments[1],
    headers: {'Content-Type': arguments[3]},
    body: arguments[2]
})
.then(response => response.text().then(text => {
    callback({status: response.status, text: text, error: null});
}))
.catch(error => {
    callback({status: -1, text: null, error: error.toString()});
});
"""


def _build_request_url(target_url, method, request_data_str):
    """Return the URL to load for one request.

    For GET/DELETE with non-empty ``request_data_str``:
      * a JSON object is merged into the URL's existing query string (its keys
        override existing ones; values are converted with ``str()``);
      * anything else (a raw ``k=v&...`` string, or non-object JSON) is appended
        verbatim after the existing query.
    The query is rebuilt with urlparse/urlunparse so a ``#fragment`` stays at the
    end of the URL. For other methods, or empty data, ``target_url`` is returned
    unchanged.
    """
    if method.upper() not in ("GET", "DELETE") or not request_data_str:
        return target_url

    parsed = urlparse(target_url)
    try:
        params = json.loads(request_data_str)
    except json.JSONDecodeError:
        params = None  # not JSON: treat it as a raw query string

    if isinstance(params, dict):
        # keep_blank_values so existing "a=" parameters are not silently dropped.
        query = parse_qs(parsed.query, keep_blank_values=True)
        query.update({key: [str(value)] for key, value in params.items()})
        new_query = urlencode(query, doseq=True)
    elif parsed.query:
        new_query = f"{parsed.query}&{request_data_str}"
    else:
        new_query = request_data_str
    return urlunparse(parsed._replace(query=new_query))


def _record_outcome(duration, succeeded, message=None):
    """Record one request in the shared counters while holding ``lock``.

    A successful request adds its ``duration`` (seconds) to ``results``; any
    failure adds ``float('inf')``, so timing statistics computed from the finite
    entries cover successful requests only. ``message``, if given, is printed
    while the lock is held.
    """
    global successful_requests, failed_requests
    with lock:
        if succeeded:
            successful_requests += 1
            results.append(duration)
        else:
            failed_requests += 1
            results.append(float('inf'))
        if message:
            print(message)


def worker(target_url_orig, method, request_data_str, user_agent, chromedriver_path, total_requests_per_thread):
    """
    Worker thread: perform ``total_requests_per_thread`` requests in one headless Chrome session.

    Each iteration loads the target URL (with GET/DELETE query parameters merged
    in) and waits up to 25 s in total for Cloudflare challenge titles to clear.
    For POST/PUT with data it then issues an in-page ``fetch()`` carrying
    ``request_data_str``, sent as JSON when it parses as JSON, otherwise as
    form-urlencoded. A page load alone is assumed to be status 200. A request
    counts as successful only if its final status is 2xx.

    Args:
        target_url_orig: Base URL under test.
        method: HTTP method (GET, POST, PUT, DELETE; case-insensitive).
        request_data_str: Query parameters (GET/DELETE) or body (POST/PUT), or None.
        user_agent: Custom User-Agent string; falsy keeps Chrome's default.
        chromedriver_path: Path to chromedriver; falsy uses Selenium's default
            ``Service()`` lookup.
        total_requests_per_thread: Number of requests this thread performs.

    Outcomes go to the module-level ``results`` / ``successful_requests`` /
    ``failed_requests`` through ``_record_outcome``. If the browser cannot start,
    or another error escapes the request loop, every request not yet recorded is
    counted as failed.
    """
    global failed_requests

    method_upper = method.upper()

    options = Options()
    if user_agent:
        options.add_argument(f"user-agent={user_agent}")
    options.add_argument("--headless")
    options.add_argument("--disable-gpu")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")
    options.add_argument("--log-level=3")  # suppress console logs from Chrome/ChromeDriver
    options.add_experimental_option('excludeSwitches', ['enable-logging'])

    webdriver_service = Service(executable_path=chromedriver_path) if chromedriver_path else Service()
    driver = None  # lets the finally block know whether a browser was started
    attempted = 0  # requests whose outcome has already been recorded

    try:
        # Everything below is identical for every iteration, so compute it once.
        url_to_load = _build_request_url(target_url_orig, method_upper, request_data_str)
        is_body_method = method_upper in ("POST", "PUT")
        sends_body = is_body_method and bool(request_data_str)
        if is_body_method and not sends_body:
            print(f"警告: {method} 请求没有提供数据。将仅访问 URL。")
        content_type = 'application/x-www-form-urlencoded;charset=UTF-8'
        if sends_body:
            try:
                json.loads(request_data_str)
                content_type = 'application/json;charset=UTF-8'
            except json.JSONDecodeError:
                pass  # not JSON: keep form-urlencoded

        # One browser per thread, reused across all of this thread's requests.
        driver = webdriver.Chrome(service=webdriver_service, options=options)
        driver.set_page_load_timeout(30)
        driver.set_script_timeout(30)  # bounds execute_async_script (the fetch below)

        for i in range(total_requests_per_thread):
            where = f"(请求 {i + 1}/{total_requests_per_thread}): URL={target_url_orig}, 方法={method}"
            duration, succeeded, message = None, False, None
            try:
                start_time = time.perf_counter()
                driver.get(url_to_load)

                # A single 25 s budget for the whole challenge, not 25 s per marker.
                try:
                    WebDriverWait(driver, 25).until(
                        lambda d: not any(marker in d.title for marker in _CLOUDFLARE_TITLE_MARKERS)
                    )
                except TimeoutException:
                    raise WebDriverException("Cloudflare challenge page timed out.") from None

                response_status = 200  # assume success: page loaded and is not a challenge page
                response_text_snippet = driver.title

                if sends_body:
                    fetch_result = driver.execute_async_script(
                        _FETCH_SCRIPT, target_url_orig, method_upper, request_data_str, content_type)
                    if fetch_result.get("error"):
                        raise WebDriverException(f"JavaScript Fetch 错误: {fetch_result['error']}")
                    response_status = fetch_result.get("status", -1)
                    response_text_snippet = (fetch_result.get("text") or "")[:100]

                duration = time.perf_counter() - start_time
                succeeded = 200 <= response_status < 300
                if not succeeded:
                    message = (f"请求失败 {where}, 状态码={response_status}, "
                               f"响应片段='{response_text_snippet}'")
            except TimeoutException:  # page load or script timeout for this request
                message = f"页面加载或脚本超时 {where}"
            except WebDriverException as e:  # may be per-request or browser-fatal; try the next one
                message = f"WebDriver 错误 {where}, 错误='{str(e)[:200]}'"
            except Exception as e:  # any other unexpected per-request error
                message = f"未知错误 {where}, 错误='{str(e)[:200]}'"

            # Count the request before recording, so the fatal handler below never
            # double-counts it even if recording (printing) raises.
            attempted += 1
            _record_outcome(duration, succeeded, message)

    except Exception as e:  # browser failed to start, or a fatal error escaped the loop
        remaining = total_requests_per_thread - attempted
        with lock:
            failed_requests += remaining
            results.extend([float('inf')] * remaining)
            print(f"工作线程启动/严重错误: 错误='{str(e)[:200]}'. 此线程剩余的 {remaining} 个请求已标记为失败。")
    finally:
        if driver is not None:
            try:
                driver.quit()
            except WebDriverException:
                pass  # best-effort cleanup; the browser may already have died


def _prompt_positive_int(prompt, default, non_positive_msg, invalid_msg):
    """Prompt until a positive integer is entered; empty input returns ``default``."""
    while True:
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            print(invalid_msg)
            continue
        if value > 0:
            return value
        print(non_positive_msg)


def _prompt_request_data(http_method):
    """Ask for the request body (POST/PUT) or query parameters (GET/DELETE).

    Returns the raw string to send, or None when no data was requested. For
    POST/PUT JSON input, an empty answer becomes ``"{}"`` and invalid JSON
    re-prompts. For GET/DELETE, JSON-looking but unparsable input only triggers
    a warning and is kept as a plain string.
    """
    if http_method in ("POST", "PUT"):
        print(f"对于 {http_method} 请求:")
        while True:
            data_type_choice = input(
                "  请求数据是 JSON 格式 (j) 还是表单/查询字符串格式 (f)? (j/f, 默认为 j): ").strip().lower()
            if data_type_choice in ('j', 'json', ''):
                data = input("  请输入 JSON 格式的请求数据: ").strip()
                if not data:  # empty input means "send an empty JSON object"
                    print("  使用空JSON对象 {}。")
                    return "{}"
                try:
                    json.loads(data)  # validate
                    return data
                except json.JSONDecodeError:
                    print("  输入的不是有效的 JSON 格式。")
            elif data_type_choice in ('f', 'form'):
                return input(
                    "  请输入表单/查询字符串格式的数据 (例如: key1=value1&key2=value2): ").strip()
            else:
                print("  无效选择。")

    if http_method in ("GET", "DELETE"):
        has_params = input(f"是否为 {http_method} 请求添加查询参数? (y/n, 默认为 n): ").strip().lower()
        if has_params not in ('y', 'yes'):
            return None
        data = input(
            "  请输入查询参数 (例如: key1=value1&key2=value2 或 JSON 对象 {\"key\":\"value\"}): ").strip()
        if not data:
            return None  # no params if empty
        if data.startswith("{") and data.endswith("}"):
            try:
                json.loads(data)  # validate only when it looks like JSON
            except json.JSONDecodeError:
                print("  警告: 输入看起来像JSON但无法解析,将作为普通字符串处理。")
        return data

    return None


def main():
    """
    Interactively configure and run the Selenium stress test, then print a summary.

    Prompts for the target URL, HTTP method, request data, ChromeDriver path,
    thread count, total request count and User-Agent. The requests are split as
    evenly as possible over ``min(threads, requests)`` worker threads, each
    driving its own browser. Finally it reports totals plus min/avg/max timing
    over the finite entries in ``results``.
    """
    global successful_requests, failed_requests

    print("Selenium 压力测试脚本")
    print("警告: 此脚本使用浏览器实例,资源消耗较大。建议使用较低的并发数。")
    print("请确保已安装 Selenium (pip install selenium) 和 ChromeDriver。")

    while True:
        target_url = input("请输入要测试的网站 URL: ").strip()
        if target_url:
            break
        print("URL 不能为空,请重新输入。")

    while True:
        http_method_str = input("请输入 HTTP 请求方法 (GET, POST, PUT, DELETE, 默认为 GET): ").strip().upper()
        if not http_method_str:
            http_method = "GET"
            break
        if http_method_str in ("GET", "POST", "PUT", "DELETE"):
            http_method = http_method_str
            break
        print("无效的 HTTP 方法。请输入 GET, POST, PUT, 或 DELETE。")

    request_data_input = _prompt_request_data(http_method)  # string form of the data, or None
    chromedriver_path_input = input("请输入 ChromeDriver 的完整路径 (留空则尝试使用系统 PATH): ").strip()
    num_threads = _prompt_positive_int(
        "请输入并发线程数 (默认为 10): ", 10,
        "并发线程数必须大于 0。", "请输入有效的数字作为并发线程数。")
    total_requests = _prompt_positive_int(
        "请输入总请求数 (默认为 100): ", 100,
        "总请求数必须大于 0。", "请输入有效的数字作为总请求数。")
    custom_user_agent_input = input("请输入自定义 User-Agent (留空则使用浏览器默认): ").strip()

    if not target_url.startswith(('http://', 'https://')):
        target_url = 'http://' + target_url

    # Both counts are validated to be >= 1, so there is at least one thread, and
    # because threads <= requests, every thread gets at least one request.
    actual_num_threads = min(num_threads, total_requests)
    requests_per_thread_val, remaining_requests_val = divmod(total_requests, actual_num_threads)

    print("\n--- 测试配置 ---")
    print(f"目标 URL: {target_url}")
    print(f"HTTP 方法: {http_method}")
    if request_data_input:
        print(f"请求数据: {request_data_input[:200]}{'...' if len(request_data_input) > 200 else ''}")
    if chromedriver_path_input:
        print(f"ChromeDriver路径: {chromedriver_path_input}")
    else:
        print("ChromeDriver路径: 将尝试系统 PATH")
    if custom_user_agent_input:
        print(f"User-Agent: {custom_user_agent_input}")
    print(f"并发线程数 (浏览器实例数): {num_threads}")
    # Uses the adjusted thread count, so this never shows 0 when threads > requests.
    print(f"每个线程的请求数 (页面加载数): {requests_per_thread_val}")
    print(f"总请求数 (总页面加载数): {total_requests}")
    print("--------------------")

    if actual_num_threads < num_threads:
        print(f"并发线程数 ({num_threads}) 大于总请求数 ({total_requests})。调整线程数为 {actual_num_threads}。")

    print(f"将启动 {actual_num_threads} 个并发浏览器实例。")
    print(f"大多数浏览器实例将尝试加载页面 {requests_per_thread_val} 次。")

    # Start from clean shared state so each run reports only its own requests.
    results.clear()
    successful_requests = 0
    failed_requests = 0

    start_overall_time = time.perf_counter()

    threads = []
    for i in range(actual_num_threads):
        # The first `remaining_requests_val` threads take one extra request each.
        req_count_for_this_thread = requests_per_thread_val + (1 if i < remaining_requests_val else 0)
        thread = threading.Thread(
            target=worker,
            args=(
                target_url,
                http_method,
                request_data_input,
                custom_user_agent_input,
                chromedriver_path_input,
                req_count_for_this_thread,
            ),
        )
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    total_duration_overall = time.perf_counter() - start_overall_time

    print("\n--- 测试结果 ---")
    print(f"总耗时: {total_duration_overall:.4f} 秒")
    print(f"总请求数: {total_requests}")
    print(f"成功请求数: {successful_requests}")
    print(f"失败请求数: {failed_requests}")

    # float('inf') entries mark failed attempts; time only the finite entries.
    valid_results = [r for r in results if r != float('inf')]
    if valid_results:
        average_response_time = sum(valid_results) / len(valid_results)
        print(f"平均响应时间 (成功请求): {average_response_time:.4f} 秒")
        print(f"最快响应时间 (成功请求): {min(valid_results):.4f} 秒")
        print(f"最慢响应时间 (成功请求): {max(valid_results):.4f} 秒")
    else:
        print("没有成功的请求来计算响应时间。")


if __name__ == "__main__":
    # Run the interactive stress test only when executed as a script, not on import.
    main()