feat: streaming api
All checks were successful
Build container / build-docker (push) Successful in 17m34s
All checks were successful
Build container / build-docker (push) Successful in 17m34s
This commit is contained in:
31
.github/workflows/docker-build.yml
vendored
Normal file
31
.github/workflows/docker-build.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Builds the Docker image from ./Dockerfile and pushes it to the Harbor registry.
name: Build container

env:
  # NOTE: the image tag is fixed here; every push to main re-pushes this same
  # tag until VERSION is bumped by hand.
  VERSION: 0.0.1
  # REGISTRY (with scheme) is used for login, REGISTRY_NAME (bare host) for tags.
  REGISTRY: https://harbor.bwgdi.com
  REGISTRY_NAME: harbor.bwgdi.com

on:
  push:
    branches:
      - main
  workflow_dispatch:

jobs:
  build-docker:
    runs-on: builder-ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      # The login targets the Harbor registry configured above, not Docker Hub.
      - name: Login to Harbor registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ secrets.BWGDI_NAME }}
          password: ${{ secrets.BWGDI_TOKEN }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      - name: Build and push
        uses: docker/build-push-action@v4
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: ${{ env.REGISTRY_NAME }}/library/qwen3-asr:${{ env.VERSION }}
|
||||||
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Runtime image for the unified Qwen3-ASR FastAPI server
# (examples/api_unified_fastapi.py).
FROM python:3.12-slim

# Install ffmpeg and drop the apt package lists afterwards to keep the layer small.
RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*

# Copy the uv / uvx binaries from the uv image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /app
COPY . .
RUN uv sync

# Model name or local path; override with `docker run -e ASR_MODEL_PATH=...`.
ENV ASR_MODEL_PATH="Qwen/Qwen3-ASR-1.7B"

# The server's --port option defaults to 8000 and the CMD below does not
# override it, so 8000 is the port the container actually listens on.
EXPOSE 8000

# `exec` replaces the shell so signals from `docker stop` are delivered to the
# uv process directly; the variable is quoted so paths with spaces survive.
CMD ["sh", "-c", "exec uv run examples/api_unified_fastapi.py --asr-model-path \"$ASR_MODEL_PATH\""]
|
||||||
165
README_BW.md
Normal file
165
README_BW.md
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
# Qwen3-ASR
|
||||||
|
|
||||||
|
https://github.com/QwenLM/Qwen3-ASR
|
||||||
|
|
||||||
|
## 📦 Version History
|
||||||
|
|
||||||
|
| Version | Date | Summary |
|
||||||
|
|---------|------------|---------------------------------|
|
||||||
|
| 0.0.1 | 2026-04-22 | Initial version |
|
||||||
|
|
||||||
|
### 🔄 Version Details
|
||||||
|
|
||||||
|
#### 🆕 0.0.1 – *2026-04-22*
|
||||||
|
- ✅ **Core Features**
|
||||||
|
- Initial Qwen3-ASR integration
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker pull harbor.bwgdi.com/library/qwen3-asr:0.0.1
|
||||||
|
|
||||||
|
# Run with custom model path
|
||||||
|
# -e ASR_MODEL_PATH: Model name or local path inside container
|
||||||
|
docker run -d --restart always -p 8000:8000 --gpus all \
|
||||||
|
-e ASR_MODEL_PATH="Qwen/Qwen3-ASR-1.7B" \
|
||||||
|
--mount type=bind,source=/path/to/your/models,target=/models \
|
||||||
|
harbor.bwgdi.com/library/qwen3-asr:0.0.1
|
||||||
|
```
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
|
||||||
|
## Non-streaming (HTTP POST)
|
||||||
|
Transcribe an entire audio file.
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8000/asr/transcribe \
|
||||||
|
-F "file=@audio.wav" \
|
||||||
|
-F "language=Chinese"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Streaming (WebSocket)
|
||||||
|
Real-time incremental transcription.
|
||||||
|
- **URL**: `ws://localhost:8000/asr/stream`
|
||||||
|
- **Protocol**:
|
||||||
|
- Client sends `bytes`: float32 PCM 16kHz audio chunks.
|
||||||
|
- Client sends `text`: `{"command": "finish"}` to stop.
|
||||||
|
- Server sends `text`: `{"session_id": ..., "language": ..., "text": ..., "is_final": bool}`
|
||||||
|
|
||||||
|
Example using Python `websockets`:
|
||||||
|
```python
|
||||||
|
# coding=utf-8
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _to_mono_16k(wav, sr):
    """Return *wav* as 1-D float32 audio sampled at 16 kHz.

    Multi-channel input (2-D, frames x channels, as ``soundfile.read`` returns
    it) is averaged down to one channel: ``np.interp`` only accepts 1-D data,
    and the stream sent to the server is a flat sequence of float32 samples
    with no channel layout. Resampling is plain linear interpolation.
    """
    wav = np.asarray(wav, dtype=np.float32)
    if wav.ndim > 1:
        wav = wav.mean(axis=1).astype(np.float32)
    # Nothing to resample for 16 kHz or empty input (np.interp rejects empty arrays).
    if sr == 16000 or wav.shape[0] == 0:
        return wav
    dur = wav.shape[0] / float(sr)
    n16 = int(round(dur * 16000))
    x_old = np.linspace(0.0, dur, num=wav.shape[0], endpoint=False)
    x_new = np.linspace(0.0, dur, num=n16, endpoint=False)
    return np.interp(x_new, x_old, wav).astype(np.float32)


async def stream_audio_to_api(uri: str, audio_path: str, chunk_size_ms: int = 500):
    """
    Load audio and stream it in chunks to the ASR WebSocket API.

    Args:
        uri: WebSocket URI of the streaming endpoint.
        audio_path: Local file path, or an http(s) URL that is downloaded first.
        chunk_size_ms: Duration of each audio chunk sent to the server, in milliseconds.

    Intermediate and final transcriptions are printed. WebSocket errors are
    logged rather than raised; errors while downloading or decoding the audio
    propagate to the caller.
    """
    logger.info(f"Loading audio from {audio_path}...")

    # Only real http(s) URLs are downloaded; a local file whose name merely
    # starts with "http" is still opened from disk.
    if audio_path.startswith(("http://", "https://")):
        req = urllib.request.Request(audio_path, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=30) as resp:
            audio_bytes = resp.read()
        f = io.BytesIO(audio_bytes)
    else:
        f = audio_path

    # Read audio as Float32
    wav, sr = sf.read(f, dtype="float32", always_2d=False)

    if wav.ndim > 1:
        logger.warning(f"Audio has {wav.shape[1]} channels, downmixing to mono...")
    if sr != 16000:
        logger.warning(f"Audio sample rate is {sr}, resampling to 16000...")
    wav = _to_mono_16k(wav, sr)
    sr = 16000

    # At least one sample per chunk, otherwise the send loop could never advance.
    chunk_samples = max(1, int(sr * chunk_size_ms / 1000))

    logger.info(f"Connecting to WebSocket at {uri}...")
    try:
        async with websockets.connect(uri) as websocket:
            logger.info("Connected. Streaming audio...")

            for call_id, pos in enumerate(range(0, len(wav), chunk_samples), start=1):
                chunk = wav[pos : pos + chunk_samples]

                # Send binary Float32 data
                await websocket.send(chunk.tobytes())

                # Wait for the intermediate result of this chunk
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=2.0)
                except asyncio.TimeoutError:
                    logger.warning(f"Timeout waiting for response on chunk {call_id}")
                else:
                    result = json.loads(response)
                    if "error" in result:
                        logger.error(f"API Error: {result['error']}")
                        return

                    lang = result.get("language", "unknown")
                    text = result.get("text", "")
                    print(f"[Chunk {call_id:03d}] Lang: {lang:7s} | Text: {text}")

                # To simulate real-time pacing, sleep chunk_size_ms / 1000 seconds here.

            # Send finish command
            logger.info("Finished streaming audio. Sending 'finish' command...")
            await websocket.send(json.dumps({"command": "finish"}))

            # Wait for the final response. If an earlier chunk timed out, its
            # late intermediate reply may still be queued, so skip messages
            # until the one flagged is_final arrives.
            try:
                while True:
                    final_response = await asyncio.wait_for(websocket.recv(), timeout=5.0)
                    final_result = json.loads(final_response)
                    if "error" in final_result:
                        logger.error(f"API Error: {final_result['error']}")
                        return
                    if final_result.get("is_final"):
                        break
                print("\n" + "="*50)
                print("FINAL RESULT:")
                print(f"Language: {final_result.get('language')}")
                print(f"Text: {final_result.get('text')}")
                print("="*50)
            except asyncio.TimeoutError:
                logger.error("Timeout waiting for final response")

    except Exception as e:
        logger.error(f"WebSocket Error: {e}")
|
||||||
|
|
||||||
|
def main():
    """Read the command-line options and run one streaming session against the API."""
    cli = argparse.ArgumentParser(description="Qwen3-ASR Streaming API Client Test")
    cli.add_argument(
        "--url",
        default="ws://localhost:8000/asr/stream",
        help="WebSocket API URI",
    )
    cli.add_argument(
        "--audio",
        default="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav",
        help="Path or URL to audio file",
    )
    cli.add_argument(
        "--chunk-ms",
        type=int,
        default=1000,
        help="Chunk size in milliseconds",
    )
    opts = cli.parse_args()

    # Build the coroutine first, then hand it to a fresh event loop.
    session = stream_audio_to_api(opts.url, opts.audio, opts.chunk_ms)
    asyncio.run(session)


if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
|
```
|
||||||
128
examples/api_streaming_fastapi.py
Normal file
128
examples/api_streaming_fastapi.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from qwen_asr import Qwen3ASRModel
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI(title="Qwen3-ASR Streaming API")
|
||||||
|
|
||||||
|
# Global ASR model instance
|
||||||
|
asr_model: Optional[Qwen3ASRModel] = None
|
||||||
|
server_args = None
|
||||||
|
|
||||||
|
@app.on_event("startup")
async def startup_event():
    """Create the ASR model at application startup and publish it as the module-level ``asr_model``."""
    global asr_model
    logger.info(f"Loading ASR model from {server_args.asr_model_path}...")

    # Constructor options all come from the parsed command-line arguments.
    model_options = {
        "model": server_args.asr_model_path,
        "gpu_memory_utilization": server_args.gpu_memory_utilization,
        "max_new_tokens": server_args.max_new_tokens,
    }
    # The vLLM backend is the one used for streaming.
    asr_model = Qwen3ASRModel.LLM(**model_options)

    logger.info("Model loaded successfully.")
|
||||||
|
|
||||||
|
@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Run one streaming transcription session over a WebSocket.

    Protocol:
        client -> bytes: float32 PCM 16 kHz audio chunk
        client -> text : {"command": "finish"} ends the session
        server -> text : {"session_id": ..., "language": ..., "text": ..., "is_final": bool}

    After each audio chunk the current intermediate transcription is sent
    back. When the session ends, the stream is finished and, if the client is
    still connected, a last message with ``is_final`` set to True is sent.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info(f"New session started: {session_id}")

    # Initialize streaming state for this session
    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )

    # Whether the peer can still be written to; cleared once it disconnects.
    client_connected = True

    try:
        while True:
            message = await websocket.receive()

            # The low-level receive() hands back the disconnect event as a
            # message rather than raising, so it must be checked explicitly;
            # calling receive() again after it would raise RuntimeError.
            if message.get("type") == "websocket.disconnect":
                client_connected = False
                logger.info(f"Client disconnected: {session_id}")
                break

            # A key may be present with value None, so test the values.
            raw_bytes = message.get("bytes")
            text_payload = message.get("text")

            if raw_bytes is not None:
                # Binary audio data (Float32, 16kHz)
                if len(raw_bytes) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (Float32)"})
                    continue

                # Convert bytes to numpy array
                wav = np.frombuffer(raw_bytes, dtype=np.float32)

                # NOTE: this call is made synchronously inside the coroutine,
                # so the event loop serves no other connection while it runs.
                asr_model.streaming_transcribe(wav, state)

                # Send back current intermediate transcription
                await websocket.send_json({
                    "session_id": session_id,
                    "language": getattr(state, "language", "") or "",
                    "text": getattr(state, "text", "") or "",
                    "is_final": False
                })

            elif text_payload is not None:
                # Command message
                try:
                    text_data = json.loads(text_payload)
                except json.JSONDecodeError:
                    logger.warning(f"Received invalid JSON text: {text_payload}")
                    continue
                # Valid JSON that is not an object (e.g. a list) carries no command.
                if isinstance(text_data, dict) and text_data.get("command") == "finish":
                    logger.info(f"Finish command received for session: {session_id}")
                    break

    except WebSocketDisconnect:
        client_connected = False
        logger.info(f"Client disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Error in session {session_id}: {e}", exc_info=True)
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            # Best effort only: the error could not be delivered, so treat the
            # peer as gone.
            client_connected = False
    finally:
        # Finish transcription and send final results
        try:
            asr_model.finish_streaming_transcribe(state)
            if client_connected:
                await websocket.send_json({
                    "session_id": session_id,
                    "language": getattr(state, "language", "") or "",
                    "text": getattr(state, "text", "") or "",
                    "is_final": True
                })
                logger.info(f"Sent final result for session: {session_id}")
        except Exception as e:
            logger.error(f"Error while finishing session {session_id}: {e}")

        if client_connected:
            try:
                await websocket.close()
            except Exception:
                # Closing is best effort; the connection may already be down.
                pass
        logger.info(f"Session closed: {session_id}")
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the streaming server's command-line options and return the resulting namespace."""
    parser = argparse.ArgumentParser(description="Qwen3-ASR Streaming API (vLLM backend)")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--asr-model-path", dict(default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")),
        ("--host", dict(default="0.0.0.0", help="Bind host")),
        ("--port", dict(type=int, default=8000, help="Bind port")),
        ("--gpu-memory-utilization", dict(type=float, default=0.8, help="vLLM GPU memory utilization")),
        ("--max-new-tokens", dict(type=int, default=32, help="Max new tokens to generate per streaming call. Small value is recommended for low latency.")),
        ("--unfixed-chunk-num", dict(type=int, default=4, help="Number of unfixed chunks in streaming")),
        ("--unfixed-token-num", dict(type=int, default=5, help="Number of unfixed tokens in streaming")),
        ("--chunk-size-sec", dict(type=float, default=1.0, help="Size of each chunk in seconds")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


if __name__ == "__main__":
    # The parsed options are stored at module level, where the startup hook
    # and the WebSocket endpoint read them.
    server_args = parse_args()
    uvicorn.run(app, host=server_args.host, port=server_args.port)
|
||||||
188
examples/api_unified_fastapi.py
Normal file
188
examples/api_unified_fastapi.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
Unified Qwen3-ASR API Server
============================
The model is loaded only once and serves both:
- POST /asr/transcribe   (non-streaming, transcribes a whole audio file)
- WS   /asr/stream       (streaming, real-time incremental transcription)

Launch example:
    uv run examples/api_unified_fastapi.py --asr-model-path Qwen/Qwen3-ASR-1.7B

Non-streaming call example (Python):
    import requests
    with open("audio.wav", "rb") as f:
        resp = requests.post("http://localhost:8000/asr/transcribe",
                             files={"file": ("audio.wav", f, "audio/wav")},
                             data={"context": "", "language": ""})
    print(resp.json())
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from qwen_asr import Qwen3ASRModel
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI(title="Qwen3-ASR Unified API")
|
||||||
|
|
||||||
|
# ── 全局单例 ──────────────────────────────────────────────────────────────────
|
||||||
|
asr_model: Optional[Qwen3ASRModel] = None
|
||||||
|
server_args = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
async def startup_event():
    """Create the shared ASR model at application startup and publish it as the module-level ``asr_model``."""
    global asr_model
    logger.info(f"Loading ASR model from {server_args.asr_model_path} ...")

    # Constructor options all come from the parsed command-line arguments.
    model_options = {
        "model": server_args.asr_model_path,
        "gpu_memory_utilization": server_args.gpu_memory_utilization,
        "max_new_tokens": server_args.max_new_tokens,
    }
    asr_model = Qwen3ASRModel.LLM(**model_options)

    logger.info("Model loaded successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
# ── 非流式端点(HTTP POST multipart)─────────────────────────────────────────
|
||||||
|
|
||||||
|
@app.post("/asr/transcribe")
async def transcribe_endpoint(
    file: UploadFile = File(..., description="音频文件(wav/mp3/flac 等 soundfile 支持的格式)"),
    context: str = Form(default="", description="可选上下文"),
    language: str = Form(default="", description="可选强制语言,如 Chinese / English"),
):
    """
    Non-streaming transcription of a whole audio file.
    Upload the audio as multipart/form-data; the final transcription is returned.

    Responses:
        200: {"language": ..., "text": ...}
        400: {"error": ...} when the upload is empty or cannot be decoded
        500: {"error": ...} for any other failure

    curl example:
        curl -X POST http://localhost:8000/asr/transcribe \\
             -F "file=@audio.wav" -F "context=" -F "language="
    """
    try:
        raw = await file.read()
        if not raw:
            return JSONResponse(status_code=400, content={"error": "Uploaded file is empty"})

        # A decoding failure means the upload is bad, so report it as a
        # client error instead of a server error.
        try:
            with io.BytesIO(raw) as buf:
                wav, sr = sf.read(buf, dtype="float32", always_2d=False)
        except Exception as e:
            logger.warning(f"Could not decode uploaded audio: {e}")
            return JSONResponse(status_code=400, content={"error": f"Could not decode audio: {e}"})
        wav = np.asarray(wav, dtype=np.float32)

        # NOTE: this call is made synchronously inside the coroutine, so the
        # event loop serves no other request while it runs.
        results = asr_model.transcribe(
            audio=(wav, sr),
            context=context,
            # A blank language means no language is forced.
            language=language.strip() or None,
        )
        if not results:
            return JSONResponse(status_code=500, content={"error": "ASR model returned no result"})
        r = results[0]
        return {"language": r.language, "text": r.text}
    except Exception as e:
        logger.error(f"Transcribe error: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
# ── 流式端点(WebSocket)──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@app.websocket("/asr/stream")
async def websocket_endpoint(websocket: WebSocket):
    """
    Streaming incremental transcription.

    Protocol:
        client -> bytes : float32 PCM 16 kHz audio chunk
        client -> text  : {"command": "finish"} ends the session
        server -> text  : {"session_id": ..., "language": ..., "text": ..., "is_final": bool}

    After each audio chunk the current intermediate transcription is sent
    back. When the session ends, the stream is finished and, if the client is
    still connected, a last message with ``is_final`` set to True is sent.
    """
    await websocket.accept()
    session_id = uuid.uuid4().hex
    logger.info(f"Stream session started: {session_id}")

    state = asr_model.init_streaming_state(
        unfixed_chunk_num=server_args.unfixed_chunk_num,
        unfixed_token_num=server_args.unfixed_token_num,
        chunk_size_sec=server_args.chunk_size_sec,
    )

    # Whether the peer can still be written to; cleared once it disconnects.
    client_connected = True

    try:
        while True:
            message = await websocket.receive()

            # The low-level receive() hands back the disconnect event as a
            # message rather than raising, so it must be checked explicitly;
            # calling receive() again after it would raise RuntimeError.
            if message.get("type") == "websocket.disconnect":
                client_connected = False
                logger.info(f"Client disconnected: {session_id}")
                break

            # A key may be present with value None, so test the values.
            raw = message.get("bytes")
            text_payload = message.get("text")

            if raw is not None:
                if len(raw) % 4 != 0:
                    await websocket.send_json({"error": "Data length must be multiple of 4 bytes (float32)"})
                    continue

                wav = np.frombuffer(raw, dtype=np.float32)
                # NOTE: this call is made synchronously inside the coroutine,
                # so the event loop serves no other connection while it runs.
                asr_model.streaming_transcribe(wav, state)

                await websocket.send_json({
                    "session_id": session_id,
                    "language": state.language or "",
                    "text": state.text or "",
                    "is_final": False,
                })

            elif text_payload is not None:
                try:
                    cmd = json.loads(text_payload)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON: {text_payload}")
                    continue
                # Valid JSON that is not an object (e.g. a list) carries no command.
                if isinstance(cmd, dict) and cmd.get("command") == "finish":
                    logger.info(f"Finish command received: {session_id}")
                    break

    except WebSocketDisconnect:
        client_connected = False
        logger.info(f"Client disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Error in session {session_id}: {e}", exc_info=True)
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            # Best effort only: the error could not be delivered, so treat the
            # peer as gone.
            client_connected = False
    finally:
        try:
            asr_model.finish_streaming_transcribe(state)
            if client_connected:
                await websocket.send_json({
                    "session_id": session_id,
                    "language": state.language or "",
                    "text": state.text or "",
                    "is_final": True,
                })
                logger.info(f"Final result sent: {session_id}")
        except Exception as e:
            logger.error(f"Error finishing session {session_id}: {e}")

        if client_connected:
            try:
                await websocket.close()
            except Exception:
                # Closing is best effort; the connection may already be down.
                pass
        logger.info(f"Session closed: {session_id}")
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the unified server's command-line options and return the resulting namespace."""
    parser = argparse.ArgumentParser(description="Qwen3-ASR Unified API (streaming + non-streaming)")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--asr-model-path", dict(default="Qwen/Qwen3-ASR-1.7B", help="Model name or local path")),
        ("--host", dict(default="0.0.0.0")),
        ("--port", dict(type=int, default=8000)),
        ("--gpu-memory-utilization", dict(type=float, default=0.8)),
        ("--max-new-tokens", dict(type=int, default=32, help="Max new tokens per call (streaming). Use larger value for non-streaming.")),
        ("--unfixed-chunk-num", dict(type=int, default=4)),
        ("--unfixed-token-num", dict(type=int, default=5)),
        ("--chunk-size-sec", dict(type=float, default=1.0)),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


if __name__ == "__main__":
    # The parsed options are stored at module level, where the startup hook
    # and the endpoints read them.
    server_args = parse_args()
    uvicorn.run(app, host=server_args.host, port=server_args.port)
|
||||||
@ -7,10 +7,9 @@ name = "qwen-asr"
|
|||||||
version = "0.0.6"
|
version = "0.0.6"
|
||||||
description = "Qwen-ASR python package"
|
description = "Qwen-ASR python package"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.10"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
|
|||||||
112
test.py
Normal file
112
test.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _to_mono_16k(wav, sr):
    """Return *wav* as 1-D float32 audio sampled at 16 kHz.

    Multi-channel input (2-D, frames x channels, as ``soundfile.read`` returns
    it) is averaged down to one channel: ``np.interp`` only accepts 1-D data,
    and the stream sent to the server is a flat sequence of float32 samples
    with no channel layout. Resampling is plain linear interpolation.
    """
    wav = np.asarray(wav, dtype=np.float32)
    if wav.ndim > 1:
        wav = wav.mean(axis=1).astype(np.float32)
    # Nothing to resample for 16 kHz or empty input (np.interp rejects empty arrays).
    if sr == 16000 or wav.shape[0] == 0:
        return wav
    dur = wav.shape[0] / float(sr)
    n16 = int(round(dur * 16000))
    x_old = np.linspace(0.0, dur, num=wav.shape[0], endpoint=False)
    x_new = np.linspace(0.0, dur, num=n16, endpoint=False)
    return np.interp(x_new, x_old, wav).astype(np.float32)


async def stream_audio_to_api(uri: str, audio_path: str, chunk_size_ms: int = 500):
    """
    Load audio and stream it in chunks to the ASR WebSocket API.

    Args:
        uri: WebSocket URI of the streaming endpoint.
        audio_path: Local file path, or an http(s) URL that is downloaded first.
        chunk_size_ms: Duration of each audio chunk sent to the server, in milliseconds.

    Intermediate and final transcriptions are printed. WebSocket errors are
    logged rather than raised; errors while downloading or decoding the audio
    propagate to the caller.
    """
    logger.info(f"Loading audio from {audio_path}...")

    # Only real http(s) URLs are downloaded; a local file whose name merely
    # starts with "http" is still opened from disk.
    if audio_path.startswith(("http://", "https://")):
        req = urllib.request.Request(audio_path, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=30) as resp:
            audio_bytes = resp.read()
        f = io.BytesIO(audio_bytes)
    else:
        f = audio_path

    # Read audio as Float32
    wav, sr = sf.read(f, dtype="float32", always_2d=False)

    if wav.ndim > 1:
        logger.warning(f"Audio has {wav.shape[1]} channels, downmixing to mono...")
    if sr != 16000:
        logger.warning(f"Audio sample rate is {sr}, resampling to 16000...")
    wav = _to_mono_16k(wav, sr)
    sr = 16000

    # At least one sample per chunk, otherwise the send loop could never advance.
    chunk_samples = max(1, int(sr * chunk_size_ms / 1000))

    logger.info(f"Connecting to WebSocket at {uri}...")
    try:
        async with websockets.connect(uri) as websocket:
            logger.info("Connected. Streaming audio...")

            for call_id, pos in enumerate(range(0, len(wav), chunk_samples), start=1):
                chunk = wav[pos : pos + chunk_samples]

                # Send binary Float32 data
                await websocket.send(chunk.tobytes())

                # Wait for the intermediate result of this chunk
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=2.0)
                except asyncio.TimeoutError:
                    logger.warning(f"Timeout waiting for response on chunk {call_id}")
                else:
                    result = json.loads(response)
                    if "error" in result:
                        logger.error(f"API Error: {result['error']}")
                        return

                    lang = result.get("language", "unknown")
                    text = result.get("text", "")
                    print(f"[Chunk {call_id:03d}] Lang: {lang:7s} | Text: {text}")

                # To simulate real-time pacing, sleep chunk_size_ms / 1000 seconds here.

            # Send finish command
            logger.info("Finished streaming audio. Sending 'finish' command...")
            await websocket.send(json.dumps({"command": "finish"}))

            # Wait for the final response. If an earlier chunk timed out, its
            # late intermediate reply may still be queued, so skip messages
            # until the one flagged is_final arrives.
            try:
                while True:
                    final_response = await asyncio.wait_for(websocket.recv(), timeout=5.0)
                    final_result = json.loads(final_response)
                    if "error" in final_result:
                        logger.error(f"API Error: {final_result['error']}")
                        return
                    if final_result.get("is_final"):
                        break
                print("\n" + "="*50)
                print("FINAL RESULT:")
                print(f"Language: {final_result.get('language')}")
                print(f"Text: {final_result.get('text')}")
                print("="*50)
            except asyncio.TimeoutError:
                logger.error("Timeout waiting for final response")

    except Exception as e:
        logger.error(f"WebSocket Error: {e}")
|
||||||
|
|
||||||
|
def main():
    """Read the command-line options and run one streaming session against the API."""
    cli = argparse.ArgumentParser(description="Qwen3-ASR Streaming API Client Test")
    cli.add_argument(
        "--url",
        default="ws://localhost:8000/asr/stream",
        help="WebSocket API URI",
    )
    cli.add_argument(
        "--audio",
        default="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav",
        help="Path or URL to audio file",
    )
    cli.add_argument(
        "--chunk-ms",
        type=int,
        default=1000,
        help="Chunk size in milliseconds",
    )
    opts = cli.parse_args()

    # Build the coroutine first, then hand it to a fresh event loop.
    session = stream_audio_to_api(opts.url, opts.audio, opts.chunk_ms)
    asyncio.run(session)


if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user