Sfoglia il codice sorgente

feat: support column numbers in extract_column_data function

mrh (aider) 1 anno fa
parent
commit
23fa1ed3bf
1 file modificato con 9 aggiunte e 3 eliminazioni
  1. 9 3
      mylib/translate_utils.py

+ 9 - 3
mylib/translate_utils.py

@@ -2,7 +2,7 @@ import os
 import logging
 import pandas as pd
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Tuple, Union
 from mylib.pdfzh_translator import OpenAITranslator
 from mylib.read_encoding_cvs import read_csv
 from mylib.logging_config import setup_logging
@@ -11,12 +11,12 @@ from mylib.logging_config import setup_logging
 setup_logging()
 logger = logging.getLogger('mylib.translate_utils')
 
-def extract_column_data(df: pd.DataFrame, column_name: str, start_row: int = 2) -> pd.Series:
+def extract_column_data(df: pd.DataFrame, column_name: Union[str, int], start_row: int = 2) -> pd.Series:
     """提取指定列的数据,默认从第3行开始
     
     Args:
         df: pandas DataFrame
-        column_name: 要提取的列名
+        column_name: 要提取的列名或列号(从0开始)
         start_row: 开始提取的行号,默认为2(第3行)
     
     Returns:
@@ -26,6 +26,12 @@ def extract_column_data(df: pd.DataFrame, column_name: str, start_row: int = 2)
         if df.empty:
             return pd.Series()
             
+        # 处理列号或列名
+        if isinstance(column_name, int):
+            if column_name < 0 or column_name >= len(df.columns):
+                raise ValueError(f"列号 {column_name} 超出范围")
+            column_name = df.columns[column_name]
+            
         # 确保列名存在
         if column_name not in df.columns:
             raise ValueError(f"列名 {column_name} 不存在")