Browse Source

add ssl for ws_clinet and 2pass offline srv (#546)

zhaomingwork 2 years ago
parent
commit
b8dd1e310b

+ 13 - 6
funasr/runtime/python/websocket/README.md

@@ -29,11 +29,13 @@ python ws_server_offline.py \
 --asr_model [asr model_name] \
 --punc_model [punc model_name] \
 --ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] 
 ```
 ##### Usage examples
 ```shell
-python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" --certfile ./server.crt --keyfile ./server.key
 ```
 
 #### ASR streaming server
@@ -43,11 +45,13 @@ python ws_server_online.py \
 --port [port id] \
 --asr_model_online [asr model_name] \
 --ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] 
 ```
 ##### Usage examples
 ```shell
-python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online" --certfile ./server.crt --keyfile ./server.key
 ```
 
 #### ASR offline/online 2pass server
@@ -59,7 +63,9 @@ python ws_server_2pass.py \
 --asr_model_online [asr model_name] \
 --punc_model [punc model_name] \
 --ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] 
 ```
 ##### Usage examples
 ```shell
@@ -86,7 +92,8 @@ python ws_client.py \
 --words_max_print [max number of words to print] \
 --audio_in [if set, loadding from wav.scp, else recording from mircrophone] \
 --output_dir [if set, write the results to output_dir] \
---send_without_sleep [only set for offline]
+--send_without_sleep [only set for offline] \
+--ssl [1 for wss connect, 0 for ws, default is 1]
 ```
 #### Usage examples
 ##### ASR offline client

+ 16 - 3
funasr/runtime/python/websocket/ws_client.py

@@ -1,7 +1,7 @@
 # -*- encoding: utf-8 -*-
 import os
 import time
-import websockets
+import websockets,ssl
 import asyncio
 # import threading
 import argparse
@@ -53,6 +53,11 @@ parser.add_argument("--output_dir",
                     type=str,
                     default=None,
                     help="output_dir")
+                    
+parser.add_argument("--ssl",
+                    type=int,
+                    default=1,
+                    help="1 for ssl connect, 0 for no ssl")
 
 args = parser.parse_args()
 args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
@@ -221,8 +226,16 @@ async def print_messge():
 
 async def ws_client(id,chunk_begin,chunk_size):
     global websocket
-    uri = "ws://{}:{}".format(args.host, args.port)
-    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
+    if  args.ssl==1:
+       ssl_context = ssl.SSLContext()
+       ssl_context.check_hostname = False
+       ssl_context.verify_mode = ssl.CERT_NONE
+       uri = "wss://{}:{}".format(args.host, args.port)
+    else:
+       uri = "ws://{}:{}".format(args.host, args.port)
+       ssl_context=None
+    print("connect to",uri)
+    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
         if args.audio_in is not None:
             task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
         else:

+ 11 - 2
funasr/runtime/python/websocket/ws_server_2pass.py

@@ -5,7 +5,7 @@ import time
 import logging
 import tracemalloc
 import numpy as np
-
+import ssl
 from parse_args import args
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
@@ -191,7 +191,16 @@ async def async_asr_online(websocket, audio_in):
                 message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
                 await websocket.send(message)
 
+if len(args.certfile)>0:
+	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+	
+	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+	ssl_cert = args.certfile
+	ssl_key = args.keyfile
 
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 asyncio.get_event_loop().run_until_complete(start_server)
 asyncio.get_event_loop().run_forever()

+ 12 - 4
funasr/runtime/python/websocket/ws_server_offline.py

@@ -5,6 +5,7 @@ import time
 import logging
 import tracemalloc
 import numpy as np
+import ssl
 
 from parse_args import args
 from modelscope.pipelines import pipeline
@@ -147,9 +148,16 @@ async def async_asr(websocket, audio_in):
                 await websocket.send(message)
                 
                 
- 
-
-
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+if len(args.certfile)>0:
+	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+	
+	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+	ssl_cert = args.certfile
+	ssl_key = args.keyfile
+
+	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 asyncio.get_event_loop().run_until_complete(start_server)
 asyncio.get_event_loop().run_forever()