first commit

This commit is contained in:
vera
2026-01-20 17:37:43 +08:00
parent e8dd956fc2
commit 721c53fe87
6 changed files with 614 additions and 0 deletions

365
test_concurrent.py Normal file
View File

@ -0,0 +1,365 @@
import asyncio
import aiohttp
import time
import statistics
import csv
from typing import List, Dict
from datetime import datetime
# ---------------- 基本配置 ----------------
BASE_URL = "http://localhost:8880/generate_tts"
REQUEST_TEMPLATE = {
"text": "你这个新买的T-shirt好cool啊是哪个brand的周末我们去新开的mall里那家Starbucks喝杯coffee吧我听说他们的new season限定款Latte很OK。"
}
# REQUEST_TEMPLATE = {
# "text": "澳门在哪里啊",
# "cfg_value": "2.0",
# "inference_timesteps": "10",
# "normalize": "true",
# "denoise": "true",
# "prompt_text": "澳门有乜嘢好食嘅?"
# }
# # 音频路径(确保文件存在)
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# headers 一般不必指定 multipart会自动设置
DEFAULT_HEADERS = {}
# ---------------- 请求逻辑 ----------------
async def tts_request(session: aiohttp.ClientSession, request_id: int) -> float:
"""执行一次请求,返回耗时(秒)"""
start = time.perf_counter()
try:
form = aiohttp.FormData()
for k, v in REQUEST_TEMPLATE.items():
form.add_field(k, v)
# form.add_field(
# "prompt_wav",
# open(PROMPT_WAV_PATH, "rb"),
# filename="2food.wav",
# content_type="audio/wav"
# )
async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
await resp.read()
if resp.status != 200:
raise RuntimeError(f"HTTP {resp.status}")
except Exception as e:
print(f"[请求 {request_id}] 出错: {e}")
return -1
return time.perf_counter() - start
# ---------------- 并发测试核心 ----------------
async def benchmark(concurrency: int, total_requests: int) -> Dict:
"""在指定并发下发起多次请求,统计性能指标"""
timings = []
errors = 0
conn = aiohttp.TCPConnector(limit=0, force_close=False)
timeout = aiohttp.ClientTimeout(total=None)
async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
sem = asyncio.Semaphore(concurrency)
async def worker(i):
nonlocal errors
async with sem:
t = await tts_request(session, i)
if t > 0:
timings.append(t)
else:
errors += 1
tasks = [asyncio.create_task(worker(i)) for i in range(total_requests)]
await asyncio.gather(*tasks)
succ = len(timings)
total_time = sum(timings)
result = {
"concurrency": concurrency,
"total_requests": total_requests,
"successes": succ,
"errors": errors,
}
if succ > 0:
result.update({
"min": min(timings),
"max": max(timings),
"avg": statistics.mean(timings),
"median": statistics.median(timings),
"qps": succ / total_time if total_time > 0 else 0
})
return result
# ---------------- CSV 写入 ----------------
def write_to_csv(filename: str, data: List[Dict]):
"""把测试结果写入 CSV 文件"""
fieldnames = ["concurrency", "total_requests", "successes", "errors",
"min", "max", "avg", "median", "qps"]
with open(filename, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for row in data:
writer.writerow(row)
print(f"\n✅ 测试结果已保存到: {filename}")
# ---------------- 主流程 ----------------
async def run_tests(concurrency_list: List[int], total_requests: dict):
print("开始压测接口:", BASE_URL)
results = []
for c in concurrency_list:
print(f"\n=== 并发数: {c} ===")
res = await benchmark(c, total_requests[c])
results.append(res)
print(f"请求总数: {res['total_requests']} | 成功: {res['successes']} | 失败: {res['errors']}")
if "avg" in res:
print(f"耗时 (s) → min={res['min']:.3f}, avg={res['avg']:.3f}, median={res['median']:.3f}, max={res['max']:.3f}")
print(f"近似 QPS: {res['qps']:.2f}")
else:
print("所有请求均失败")
# 保存结果
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
csv_filename = f"tts_benchmark_{timestamp}.csv"
write_to_csv(csv_filename, results)
if __name__ == "__main__":
# 你可以根据机器能力调整这两个参数
concurrency_list = [1, 5, 10, 20, 50, 100, 150, 200] # 并发测试范围
total_requests = {
1: 50,
5: 100,
10: 200,
20: 400,
50: 1000,
100: 2000,
150: 3000,
200: 4000,
300: 6000,
500: 10000
} # 每个并发等级下的请求数
asyncio.run(run_tests(concurrency_list, total_requests))
# import asyncio
# import aiohttp
# import time
# import statistics
# from typing import List, Dict
# # 新的接口 URL
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
# # # 固定参数(不包括音频)
# # FORM_PARAMS = {
# # "text": "澳门在哪里啊",
# # "cfg_value": "2.0",
# # "inference_timesteps": "10",
# # "normalize": "true",
# # "denoise": "true",
# # "prompt_text": "澳门有乜嘢好食嘅?"
# # }
# FORM_PARAMS = {
# "text": "澳门在哪里啊"
# }
# # 音频路径(确保文件存在)
# # PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# # headers 一般不必指定 multipart会自动设置
# DEFAULT_HEADERS = {}
# async def tts_request(session: aiohttp.ClientSession) -> float:
# """
# 发起一次 multipart/form-data 格式的 TTS 请求,返回耗时。
# """
# start = time.perf_counter()
# form = aiohttp.FormData()
# for k, v in FORM_PARAMS.items():
# form.add_field(k, v)
# # form.add_field(
# # "prompt_wav",
# # open(PROMPT_WAV_PATH, "rb"),
# # filename="2food.wav",
# # content_type="audio/wav"
# # )
# async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
# data = await resp.read()
# if resp.status != 200:
# raise RuntimeError(f"HTTP {resp.status}, body: {data[:200]!r}")
# elapsed = time.perf_counter() - start
# return elapsed
# async def benchmark(concurrency: int, total_requests: int) -> Dict:
# timings: List[float] = []
# errors = 0
# conn = aiohttp.TCPConnector(limit=0)
# timeout = aiohttp.ClientTimeout(total=None)
# async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
# sem = asyncio.Semaphore(concurrency)
# async def worker():
# nonlocal errors
# async with sem:
# try:
# t = await tts_request(session)
# timings.append(t)
# except Exception as e:
# errors += 1
# print("请求出错:", e)
# tasks = [asyncio.create_task(worker()) for _ in range(total_requests)]
# await asyncio.gather(*tasks)
# succ = len(timings)
# total_time = sum(timings) if timings else 0.0
# result = {
# "concurrency": concurrency,
# "total_requests": total_requests,
# "successes": succ,
# "errors": errors,
# }
# if succ > 0:
# result.update({
# "min": min(timings),
# "max": max(timings),
# "avg": statistics.mean(timings),
# "median": statistics.median(timings),
# "qps": succ / total_time if total_time > 0 else None
# })
# return result
# async def run_tests(concurrency_list: List[int], total_requests: int):
# print("开始并发测试,目标接口:", BASE_URL)
# for c in concurrency_list:
# print(f"\n--- 并发 = {c} ---")
# res = await benchmark(c, total_requests)
# print("总请求:", res["total_requests"])
# print("成功:", res["successes"], "失败:", res["errors"])
# if "avg" in res:
# print(f"耗时 min / avg / median / max = "
# f"{res['min']:.3f} / {res['avg']:.3f} / {res['median']:.3f} / {res['max']:.3f} 秒")
# print(f"近似 QPS = {res['qps']:.2f}")
# else:
# print("所有请求都失败了")
# if __name__ == "__main__":
# # 你可以调整这些参数
# concurrency_list = [200] # 要试的并发数列表
# total_requests = 100 # 每个并发等级下总请求数
# asyncio.run(run_tests(concurrency_list, total_requests))
# import multiprocessing as mp
# import requests
# import time
# import statistics
# from typing import Dict, List
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
# FORM_PARAMS = {
# "text": "澳门在哪里啊",
# "cfg_value": "2.0",
# "inference_timesteps": "10",
# "normalize": "true",
# "denoise": "true",
# "prompt_text": "澳门有乜嘢好食嘅?"
# }
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# def single_request() -> float:
# """用 requests 同步发一次 multipart/form-data 请求,返回耗时(秒)或抛异常。"""
# files = {
# "prompt_wav": ("2food.wav", open(PROMPT_WAV_PATH, "rb"), "audio/wav")
# }
# data = FORM_PARAMS.copy()
# start = time.perf_counter()
# resp = requests.post(BASE_URL, data=data, files=files, timeout=60)
# elapsed = time.perf_counter() - start
# if resp.status_code != 200:
# raise RuntimeError(f"HTTP {resp.status_code}, body: {resp.text[:200]!r}")
# return elapsed
# def worker_task(num_requests: int, return_list: mp.Manager().list, err_list: mp.Manager().list):
# """子进程做 num_requests 次请求,将各次耗时记录到 return_list共享 list错误次数记录到 err_list。"""
# for _ in range(num_requests):
# try:
# t = single_request()
# return_list.append(t)
# except Exception as e:
# err_list.append(str(e))
# def run_multiproc(concurrency: int, total_requests: int) -> Dict:
# """
# 用多个进程模拟并发:
# - 每个进程发 total_requestsconcurrency 次请求(向下取整或略分配)
# - 或者简单地让每个进程跑 total_requests 次(更激进)
# """
# manager = mp.Manager()
# times = manager.list()
# errs = manager.list()
# procs = []
# # 任务分配:每个子进程跑一部分请求
# per = total_requests // concurrency
# if per < 1:
# per = 1
# for i in range(concurrency):
# p = mp.Process(target=worker_task, args=(per, times, errs))
# p.start()
# procs.append(p)
# for p in procs:
# p.join()
# timings = list(times)
# errors = list(errs)
# succ = len(timings)
# total_time = sum(timings) if timings else 0.0
# ret = {
# "concurrency": concurrency,
# "total_requests": total_requests,
# "successes": succ,
# "errors": len(errors),
# }
# if succ > 0:
# ret.update({
# "min": min(timings),
# "max": max(timings),
# "avg": statistics.mean(timings),
# "median": statistics.median(timings),
# "qps": succ / total_time if total_time > 0 else None
# })
# return ret
# if __name__ == "__main__":
# concurrency = 4
# total_requests = 40
# print("开始多进程并发测试")
# result = run_multiproc(concurrency, total_requests)
# print(result)