|
|
@@ -2,7 +2,7 @@ import os
|
|
|
import logging
|
|
|
import pandas as pd
|
|
|
from pathlib import Path
|
|
|
-from typing import List, Tuple
|
|
|
+from typing import List, Tuple, Union
|
|
|
from mylib.pdfzh_translator import OpenAITranslator
|
|
|
from mylib.read_encoding_cvs import read_csv
|
|
|
from mylib.logging_config import setup_logging
|
|
|
@@ -11,12 +11,12 @@ from mylib.logging_config import setup_logging
|
|
|
setup_logging()
|
|
|
logger = logging.getLogger('mylib.translate_utils')
|
|
|
|
|
|
-def extract_column_data(df: pd.DataFrame, column_name: str, start_row: int = 2) -> pd.Series:
|
|
|
+def extract_column_data(df: pd.DataFrame, column_name: Union[str, int], start_row: int = 2) -> pd.Series:
|
|
|
"""提取指定列的数据,默认从第3行开始
|
|
|
|
|
|
Args:
|
|
|
df: pandas DataFrame
|
|
|
- column_name: 要提取的列名
|
|
|
+ column_name: 要提取的列名或列号(从0开始)
|
|
|
start_row: 开始提取的行号,默认为2(第3行)
|
|
|
|
|
|
Returns:
|
|
|
@@ -26,6 +26,12 @@ def extract_column_data(df: pd.DataFrame, column_name: str, start_row: int = 2)
|
|
|
if df.empty:
|
|
|
return pd.Series()
|
|
|
|
|
|
+ # 处理列号或列名
|
|
|
+ if isinstance(column_name, int):
|
|
|
+ if column_name < 0 or column_name >= len(df.columns):
|
|
|
+ raise ValueError(f"列号 {column_name} 超出范围")
|
|
|
+ column_name = df.columns[column_name]
|
|
|
+
|
|
|
# 确保列名存在
|
|
|
if column_name not in df.columns:
|
|
|
raise ValueError(f"列名 {column_name} 不存在")
|