video_asr_prefect.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import base64
  2. import hashlib
  3. import os
  4. import json
  5. import asyncio
  6. import prefect
  7. from typing import Optional
  8. from prefect import Flow, task
  9. import httpx
  10. from prefect import flow, task
  11. from pydantic import BaseModel
  12. from video_get import get_iframe_by_item_id, get_video_download_urls, download_video
  13. from exec_asr_client import run_asr_client # 导入语音转文本的函数
  14. from grpc_m.client import save_doc_vector
  15. from enum import Enum, auto
  16. from prefect.tasks import task_input_hash
  17. from datetime import timedelta
  18. from config import TEMP_DIR,logger
  19. from prefect.filesystems import LocalFileSystem, S3
  20. # TEMP_DIR = os.path.join("/home/user/code/open-douyin", "temp")
  21. from prefect import runtime
  22. from prefect.states import Completed, Failed
  23. VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "video")
  def cache_key_byfilename(data, args):
      """Build a cache key from the flow run's ``file_name`` and the task run name.

      Neither ``data`` nor ``args`` is read; both values come from
      ``prefect.runtime``. The two-argument shape matches what the other
      ``cache_key_fn`` callables in this file accept.

      Returns:
          The string ``"<file_name>_<task run name>"``. It is also logged at
          INFO level.
      """
      flow_file_name = runtime.flow_run.parameters.get('file_name')
      task_run_name = runtime.task_run.name
      cache_key = f"{flow_file_name}_{task_run_name}"
      logger.info(cache_key)
      return cache_key
  @task
  async def video_process_init(video_temp_dir):
      """Ensure the working directory for video files exists.

      Args:
          video_temp_dir: Path of the directory to create if it is missing.

      Returns:
          ``video_temp_dir`` unchanged, so it can be passed to later steps.

      Raises:
          FileExistsError: If the path exists but is not a directory.
      """
      # exist_ok=True avoids the check-then-create race when two runs start
      # at the same time and both see the directory as missing.
      os.makedirs(video_temp_dir, exist_ok=True)
      return video_temp_dir
  34. # @task(cache_key_fn=cache_key_byfilename, cache_expiration=timedelta(minutes=50))
  @task(cache_key_fn=task_input_hash)
  async def get_iframe_task(item_id):
      """Fetch the player-page iframe code for a video item.

      The task is cached with ``task_input_hash``, i.e. keyed on its inputs.

      Args:
          item_id: Identifier passed through to ``get_iframe_by_item_id``.

      Returns:
          The ``iframe_code`` value from the response, or a falsy value
          (``None`` / empty string) when the response does not contain one.
          The failure is logged rather than raised; callers check the result.
      """
      logger.info(f"获取播放页 iframe 代码,item_id: {item_id}")
      response = await get_iframe_by_item_id(item_id)
      # Expected response shape: {'data': {'iframe_code': ""}}
      # Guard each level so a missing or null "data" reaches the error log
      # below instead of raising AttributeError on None.
      data = (response or {}).get("data") or {}
      iframe = data.get("iframe_code")
      if not iframe:
          logger.error(f"获取 iframe 失败, {response} - {item_id}")
      # NOTE(review): a falsy result is presumably cached like any other
      # result for this item_id — confirm that is intended.
      return iframe
  # Fetching the watermark-free links costs about as much as loading the
  # iframe, so it adds little network load and caching is not recommended.
  # The returned watermark-free links also change frequently, which is another
  # reason not to cache them.
  @task
  async def get_urls_task(iframe):
      """Resolve the watermark-free download URLs for an iframe.

      Args:
          iframe: Iframe code passed through to ``get_video_download_urls``.

      Returns:
          The ``"data"`` entry of the response (``None`` if it is absent).
      """
      logger.info(f"获取视频无水印下载链接,iframe: {iframe}")
      payload = await get_video_download_urls(iframe)
      return payload.get("data")
  def custom_cache_download_task(data, args):
      """Build the cache key used by ``download_task``.

      Args:
          data: Unused.
          args: Mapping of the task's call arguments; only
              ``save_file_path`` is read.

      Returns:
          The string ``"<file_name>_<save_file_path>"``, where ``file_name``
          is taken from the current flow run's parameters.
      """
      flow_parameters = runtime.flow_run.parameters
      return "{}_{}".format(
          flow_parameters.get('file_name'),
          args.get('save_file_path'),
      )
  @task(cache_key_fn=custom_cache_download_task)
  async def download_task(urls, save_file_path):
      """Download a video to ``save_file_path``.

      The task is cached with ``custom_cache_download_task`` as its key
      function.

      Args:
          urls: Download URLs passed through to ``download_video``.
          save_file_path: Destination path passed through to ``download_video``.

      Returns:
          Whatever ``download_video`` returns.
      """
      logger.info(f"runtime.flow_run.parameters {runtime.flow_run.parameters}")
      logger.info(f"下载视频 urls : {urls}, save_file_path : {save_file_path}")
      return await download_video(urls, save_file_path)
  62. # @task(cache_key_fn=task_input_hash, cache_expiration=timedelta(minutes=50))
  @task
  async def asr_task(audio_file, output_json_path):
      """Run speech-to-text on a file and return where the result was written.

      Args:
          audio_file: Input file passed through to ``run_asr_client``.
          output_json_path: Output path passed through to ``run_asr_client``.

      Returns:
          ``output_json_path`` when ``run_asr_client`` returns a truthy value.

      Raises:
          RuntimeError: When ``run_asr_client`` returns a falsy value.
      """
      # NOTE(review): the original author assumed run_asr_client returns a JSON
      # string or a path containing the transcript — confirm against its
      # implementation. Only its truthiness is used here.
      succeeded = await run_asr_client(audio_file, output_json_path)
      if not succeeded:
          # RuntimeError is still an Exception, so existing handlers keep
          # working, and the message now names the inputs that failed.
          raise RuntimeError(
              f"ASR failed: audio_file={audio_file!r}, "
              f"output_json_path={output_json_path!r}"
          )
      # The path is returned rather than the transcript text; change this if
      # callers need the text itself.
      return output_json_path
  @task
  async def vector_task(output_json_path, collection_name, doc_path, server_addr):
      """Store a document's vectors via ``save_doc_vector``.

      The ASR output file is parsed first, so a missing or malformed file fails
      before the remote call. Its contents are not forwarded; the transcript
      text is not passed to ``save_doc_vector``.

      Args:
          output_json_path: Path of the ASR result JSON to parse.
          collection_name: Passed through to ``save_doc_vector``.
          doc_path: Passed through to ``save_doc_vector``.
          server_addr: Passed through to ``save_doc_vector``.

      Returns:
          Whatever ``save_doc_vector`` returns.

      Raises:
          OSError: If ``output_json_path`` cannot be opened.
          json.JSONDecodeError: If the file is not valid JSON.
      """
      # The context manager closes the handle; the bare open() used before
      # left it to the garbage collector.
      with open(output_json_path) as asr_file:
          json.load(asr_file)
      # NOTE(review): the arguments may need adjusting to match the
      # save_doc_vector implementation (carried over from the original author).
      return await save_doc_vector(collection_name, doc_path, server_addr)
  @flow
  async def get_download_url(item_id):
      """Resolve the watermark-free download URLs for a video item.

      Runs ``get_iframe_task`` and then ``get_urls_task``.

      Args:
          item_id: Identifier of the video item.

      Returns:
          The result of ``get_urls_task``, or a ``Failed`` state when no iframe
          code could be obtained.
      """
      # Log the run context, one field per line.
      for field in ("parameters", "id", "flow_name"):
          logger.info(f"runtime.flow_run.{field} {getattr(runtime.flow_run, field)}")

      iframe = await get_iframe_task(item_id)
      if iframe:
          return await get_urls_task(iframe)
      return Failed(message=f"获取 iframe 失败,item_id: {item_id}")
  def generate_result_storage():
      """Return the storage path for the current flow run.

      Returns:
          ``VIDEO_TEMP_DIR`` joined with the flow run's ``file_name`` parameter.

      Raises:
          KeyError: If the flow run has no ``file_name`` parameter.
      """
      # The previously assigned flow_name local was never used and is dropped.
      file_name = runtime.flow_run.parameters["file_name"]
      return os.path.join(VIDEO_TEMP_DIR, file_name)
  @flow(flow_run_name="{file_name}")
  async def download_video_flow(file_name , item_id, refresh_cache=False):
      """Download the video for ``item_id`` to ``VIDEO_TEMP_DIR/<file_name>.mp4``.

      The flow run is named after ``file_name``.

      Args:
          file_name: Base name (without extension) of the saved video.
          item_id: Identifier of the video item, passed to ``get_download_url``.
          refresh_cache: Value assigned to the ``refresh_cache`` attribute of
              ``get_iframe_task`` and ``download_task`` before they run.

      Returns:
          A ``Failed`` state when ``download_task`` returns a falsy value;
          otherwise ``None``.
      """
      # NOTE(review): these assignments mutate the module-level task objects,
      # so the setting outlives this run and is shared with any other run in
      # the same process. Presumably Prefect reads the attribute when the task
      # executes — confirm. It is set on the task itself because
      # get_iframe_task is called from inside get_download_url, not from here.
      get_iframe_task.refresh_cache = refresh_cache
      download_task.refresh_cache = refresh_cache

      logger.info(f"runtime.flow_run.parameters {runtime.flow_run.parameters}")
      logger.info(f"runtime.flow_run.id {runtime.flow_run.id}")
      logger.info(f"runtime.flow_run.flow_name {runtime.flow_run.flow_name}")

      urls = await get_download_url(item_id)
      save_video_path = os.path.join(VIDEO_TEMP_DIR, file_name + ".mp4")
      path = await download_task(urls, save_video_path)
      if not path:
          # Report the intended destination: `path` is falsy on this branch, so
          # interpolating it (as before) always lost the target path.
          return Failed(message=f"下载视频失败,urls : {urls}, save_file_path : {save_video_path}")

      logger.info(f"download urls: {urls}")
      logger.info(f"download path: {path}")
      # TODO: chain asr_task (speech-to-text) and vector_task (vector storage)
      # onto the downloaded file and return the result. The earlier
      # commented-out sketch was removed because its calls no longer matched
      # those tasks' signatures.
  # Run the flow when executed as a script.
  if __name__ == "__main__":
      # Alternative sample id (replace with a real item_id obtained elsewhere):
      # item_id = "@9VxS1/qCUc80K2etd8wkUc7912DgP/GCPpF2qwKuJ1YTb/X460zdRmYqig357zEBKzkoKYjBMUvl9Bs6h+CwYQ=="
      item_id = "@9VxS1/qCUc80K2etd8wkUc791mfoNf+EMpZzqQKiLVIaaPD660zdRmYqig357zEBdcJfEwgOpm1bLVAnSdwvLg=="
      # asyncio is already imported at module level, so no local import here.
      flow_coroutine = download_video_flow("11", item_id)
      asyncio.run(flow_coroutine)