Files
VoxCPM/test_concurrent.py
2026-01-20 17:37:43 +08:00

366 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import aiohttp
import time
import statistics
import csv
from typing import List, Dict
from datetime import datetime
# ---------------- 基本配置 ----------------
BASE_URL = "http://localhost:8880/generate_tts"
REQUEST_TEMPLATE = {
"text": "你这个新买的T-shirt好cool啊是哪个brand的周末我们去新开的mall里那家Starbucks喝杯coffee吧我听说他们的new season限定款Latte很OK。"
}
# REQUEST_TEMPLATE = {
# "text": "澳门在哪里啊",
# "cfg_value": "2.0",
# "inference_timesteps": "10",
# "normalize": "true",
# "denoise": "true",
# "prompt_text": "澳门有乜嘢好食嘅?"
# }
# # 音频路径(确保文件存在)
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# headers 一般不必指定 multipart会自动设置
DEFAULT_HEADERS = {}
# ---------------- 请求逻辑 ----------------
async def tts_request(session: aiohttp.ClientSession, request_id: int) -> float:
"""执行一次请求,返回耗时(秒)"""
start = time.perf_counter()
try:
form = aiohttp.FormData()
for k, v in REQUEST_TEMPLATE.items():
form.add_field(k, v)
# form.add_field(
# "prompt_wav",
# open(PROMPT_WAV_PATH, "rb"),
# filename="2food.wav",
# content_type="audio/wav"
# )
async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
await resp.read()
if resp.status != 200:
raise RuntimeError(f"HTTP {resp.status}")
except Exception as e:
print(f"[请求 {request_id}] 出错: {e}")
return -1
return time.perf_counter() - start
# ---------------- 并发测试核心 ----------------
async def benchmark(concurrency: int, total_requests: int) -> Dict:
"""在指定并发下发起多次请求,统计性能指标"""
timings = []
errors = 0
conn = aiohttp.TCPConnector(limit=0, force_close=False)
timeout = aiohttp.ClientTimeout(total=None)
async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
sem = asyncio.Semaphore(concurrency)
async def worker(i):
nonlocal errors
async with sem:
t = await tts_request(session, i)
if t > 0:
timings.append(t)
else:
errors += 1
tasks = [asyncio.create_task(worker(i)) for i in range(total_requests)]
await asyncio.gather(*tasks)
succ = len(timings)
total_time = sum(timings)
result = {
"concurrency": concurrency,
"total_requests": total_requests,
"successes": succ,
"errors": errors,
}
if succ > 0:
result.update({
"min": min(timings),
"max": max(timings),
"avg": statistics.mean(timings),
"median": statistics.median(timings),
"qps": succ / total_time if total_time > 0 else 0
})
return result
# ---------------- CSV 写入 ----------------
def write_to_csv(filename: str, data: List[Dict]):
"""把测试结果写入 CSV 文件"""
fieldnames = ["concurrency", "total_requests", "successes", "errors",
"min", "max", "avg", "median", "qps"]
with open(filename, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for row in data:
writer.writerow(row)
print(f"\n✅ 测试结果已保存到: {filename}")
# ---------------- 主流程 ----------------
async def run_tests(concurrency_list: List[int], total_requests: dict):
print("开始压测接口:", BASE_URL)
results = []
for c in concurrency_list:
print(f"\n=== 并发数: {c} ===")
res = await benchmark(c, total_requests[c])
results.append(res)
print(f"请求总数: {res['total_requests']} | 成功: {res['successes']} | 失败: {res['errors']}")
if "avg" in res:
print(f"耗时 (s) → min={res['min']:.3f}, avg={res['avg']:.3f}, median={res['median']:.3f}, max={res['max']:.3f}")
print(f"近似 QPS: {res['qps']:.2f}")
else:
print("所有请求均失败")
# 保存结果
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
csv_filename = f"tts_benchmark_{timestamp}.csv"
write_to_csv(csv_filename, results)
if __name__ == "__main__":
# 你可以根据机器能力调整这两个参数
concurrency_list = [1, 5, 10, 20, 50, 100, 150, 200] # 并发测试范围
total_requests = {
1: 50,
5: 100,
10: 200,
20: 400,
50: 1000,
100: 2000,
150: 3000,
200: 4000,
300: 6000,
500: 10000
} # 每个并发等级下的请求数
asyncio.run(run_tests(concurrency_list, total_requests))
# import asyncio
# import aiohttp
# import time
# import statistics
# from typing import List, Dict
# # 新的接口 URL
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
# # # 固定参数(不包括音频)
# # FORM_PARAMS = {
# # "text": "澳门在哪里啊",
# # "cfg_value": "2.0",
# # "inference_timesteps": "10",
# # "normalize": "true",
# # "denoise": "true",
# # "prompt_text": "澳门有乜嘢好食嘅?"
# # }
# FORM_PARAMS = {
# "text": "澳门在哪里啊"
# }
# # 音频路径(确保文件存在)
# # PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# # headers 一般不必指定 multipart会自动设置
# DEFAULT_HEADERS = {}
# async def tts_request(session: aiohttp.ClientSession) -> float:
# """
# 发起一次 multipart/form-data 格式的 TTS 请求,返回耗时。
# """
# start = time.perf_counter()
# form = aiohttp.FormData()
# for k, v in FORM_PARAMS.items():
# form.add_field(k, v)
# # form.add_field(
# # "prompt_wav",
# # open(PROMPT_WAV_PATH, "rb"),
# # filename="2food.wav",
# # content_type="audio/wav"
# # )
# async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
# data = await resp.read()
# if resp.status != 200:
# raise RuntimeError(f"HTTP {resp.status}, body: {data[:200]!r}")
# elapsed = time.perf_counter() - start
# return elapsed
# async def benchmark(concurrency: int, total_requests: int) -> Dict:
# timings: List[float] = []
# errors = 0
# conn = aiohttp.TCPConnector(limit=0)
# timeout = aiohttp.ClientTimeout(total=None)
# async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
# sem = asyncio.Semaphore(concurrency)
# async def worker():
# nonlocal errors
# async with sem:
# try:
# t = await tts_request(session)
# timings.append(t)
# except Exception as e:
# errors += 1
# print("请求出错:", e)
# tasks = [asyncio.create_task(worker()) for _ in range(total_requests)]
# await asyncio.gather(*tasks)
# succ = len(timings)
# total_time = sum(timings) if timings else 0.0
# result = {
# "concurrency": concurrency,
# "total_requests": total_requests,
# "successes": succ,
# "errors": errors,
# }
# if succ > 0:
# result.update({
# "min": min(timings),
# "max": max(timings),
# "avg": statistics.mean(timings),
# "median": statistics.median(timings),
# "qps": succ / total_time if total_time > 0 else None
# })
# return result
# async def run_tests(concurrency_list: List[int], total_requests: int):
# print("开始并发测试,目标接口:", BASE_URL)
# for c in concurrency_list:
# print(f"\n--- 并发 = {c} ---")
# res = await benchmark(c, total_requests)
# print("总请求:", res["total_requests"])
# print("成功:", res["successes"], "失败:", res["errors"])
# if "avg" in res:
# print(f"耗时 min / avg / median / max = "
# f"{res['min']:.3f} / {res['avg']:.3f} / {res['median']:.3f} / {res['max']:.3f} 秒")
# print(f"近似 QPS = {res['qps']:.2f}")
# else:
# print("所有请求都失败了")
# if __name__ == "__main__":
# # 你可以调整这些参数
# concurrency_list = [200] # 要试的并发数列表
# total_requests = 100 # 每个并发等级下总请求数
# asyncio.run(run_tests(concurrency_list, total_requests))
# import multiprocessing as mp
# import requests
# import time
# import statistics
# from typing import Dict, List
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
# FORM_PARAMS = {
# "text": "澳门在哪里啊",
# "cfg_value": "2.0",
# "inference_timesteps": "10",
# "normalize": "true",
# "denoise": "true",
# "prompt_text": "澳门有乜嘢好食嘅?"
# }
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# def single_request() -> float:
# """用 requests 同步发一次 multipart/form-data 请求,返回耗时(秒)或抛异常。"""
# files = {
# "prompt_wav": ("2food.wav", open(PROMPT_WAV_PATH, "rb"), "audio/wav")
# }
# data = FORM_PARAMS.copy()
# start = time.perf_counter()
# resp = requests.post(BASE_URL, data=data, files=files, timeout=60)
# elapsed = time.perf_counter() - start
# if resp.status_code != 200:
# raise RuntimeError(f"HTTP {resp.status_code}, body: {resp.text[:200]!r}")
# return elapsed
# def worker_task(num_requests: int, return_list: mp.Manager().list, err_list: mp.Manager().list):
# """子进程做 num_requests 次请求,将各次耗时记录到 return_list共享 list错误次数记录到 err_list。"""
# for _ in range(num_requests):
# try:
# t = single_request()
# return_list.append(t)
# except Exception as e:
# err_list.append(str(e))
# def run_multiproc(concurrency: int, total_requests: int) -> Dict:
# """
# 用多个进程模拟并发:
# - 每个进程发 total_requestsconcurrency 次请求(向下取整或略分配)
# - 或者简单地让每个进程跑 total_requests 次(更激进)
# """
# manager = mp.Manager()
# times = manager.list()
# errs = manager.list()
# procs = []
# # 任务分配:每个子进程跑一部分请求
# per = total_requests // concurrency
# if per < 1:
# per = 1
# for i in range(concurrency):
# p = mp.Process(target=worker_task, args=(per, times, errs))
# p.start()
# procs.append(p)
# for p in procs:
# p.join()
# timings = list(times)
# errors = list(errs)
# succ = len(timings)
# total_time = sum(timings) if timings else 0.0
# ret = {
# "concurrency": concurrency,
# "total_requests": total_requests,
# "successes": succ,
# "errors": len(errors),
# }
# if succ > 0:
# ret.update({
# "min": min(timings),
# "max": max(timings),
# "avg": statistics.mean(timings),
# "median": statistics.median(timings),
# "qps": succ / total_time if total_time > 0 else None
# })
# return ret
# if __name__ == "__main__":
# concurrency = 4
# total_requests = 40
# print("开始多进程并发测试")
# result = run_multiproc(concurrency, total_requests)
# print(result)