Pārlūkot izejas kodu

feat: add header and data search start parameters with default values

mrh (aider) 1 gadu atpakaļ
vecāks
revīzija
d27cd8d0f0
1 mainīts fails ar 37 papildinājumiem un 11 dzēšanām
  1. 37 11
      mylib/new_col_translate.py

+ 37 - 11
mylib/new_col_translate.py

@@ -51,7 +51,9 @@ def search_keywords(
     header: List[str],
     keywords: Union[str, List[str]],
     row_index: int = 0,
-    search_header: bool = False
+    search_header: bool = False,
+    header_search_start: int = 1,
+    data_search_start: int = 2
 ) -> List[str]:
     """搜索指定行中包含关键词的单元格并返回列名列表
     
@@ -61,6 +63,8 @@ def search_keywords(
         keywords: 要搜索的关键词
         row_index: 要搜索的行索引(0-based)
         search_header: 是否搜索表头行
+        header_search_start: 表头搜索起始行(1-based)
+        data_search_start: 数据搜索起始行(1-based)
     """
     if isinstance(keywords, str):
         keywords = [keywords]
@@ -69,24 +73,38 @@ def search_keywords(
     
     # 如果要搜索表头行
     if search_header:
+        # 检查header_search_start是否有效
+        if header_search_start < 1:
+            logger.warning(f"header_search_start {header_search_start} 无效,使用默认值1")
+            header_search_start = 1
+        
+        # 搜索表头行
         for col_index, cell in enumerate(header):
             if any(keyword in cell for keyword in keywords):
                 col_letter = index_to_column_letter(col_index)
                 found_columns.add(col_letter)
                 logger.debug(f"在表头 {col_letter} 列找到关键词: {cell}")
     else:
+        # 检查data_search_start是否有效
+        if data_search_start < 1:
+            logger.warning(f"data_search_start {data_search_start} 无效,使用默认值2")
+            data_search_start = 2
+        
+        # 计算实际行索引
+        actual_row_index = row_index + data_search_start - 1
+        
         # 检查行索引是否在数据范围内
-        if row_index >= len(data):
-            logger.warning(f"行索引 {row_index} 超出数据范围")
+        if actual_row_index >= len(data):
+            logger.warning(f"行索引 {actual_row_index} 超出数据范围")
             return []
         
         # 搜索数据行
-        row = data[row_index]
+        row = data[actual_row_index]
         for col_index, cell in enumerate(row):
             if any(keyword in cell for keyword in keywords):
                 col_letter = index_to_column_letter(col_index)
                 found_columns.add(col_letter)
-                logger.debug(f"在 {col_letter}{row_index + 2} 找到关键词: {cell}")
+                logger.debug(f"在 {col_letter}{actual_row_index + 2} 找到关键词: {cell}")
     
     found_columns = sorted(found_columns, key=lambda x: column_letter_to_index(x))
     logger.info(f"找到包含关键词的列: {', '.join(found_columns)}")
@@ -96,7 +114,7 @@ def translate_columns_data(
     data: List[List[str]],
     header: List[str],
     column_indices: List[int],
-    start_row: int = 1,
+    start_row: int = 2,  # 默认从第2行开始
     end_row: Optional[int] = None,
     source_lang: str = 'auto',
     target_lang: str = 'zh-CN'
@@ -109,7 +127,7 @@ def translate_columns_data(
     translator = OpenAITranslator(lang_out=target_lang, lang_in=source_lang)
     
     end_row = end_row if end_row is not None else len(data)
-    rows_to_translate = data[start_row:end_row]
+    rows_to_translate = data[start_row - 1:end_row]  # 转换为0-based索引
     
     logger.info(f"开始翻译 {start_row} 到 {end_row} 行的数据")
     
@@ -166,12 +184,14 @@ def process_csv(
     input_file: str,
     output_file: str,
     columns: Union[str, List[str]],
-    start_row: int = 1,
+    start_row: int = 2,  # 默认从第2行开始
     end_row: Optional[int] = None,
     source_lang: str = 'auto',
     target_lang: str = 'zh-CN',
     encoding: str = 'cp936',
-    header_row: int = 1
+    header_row: int = 1,
+    header_search_start: int = 1,  # 默认从第1行开始搜索表头
+    data_search_start: int = 2  # 默认从第2行开始搜索数据
 ):
     """处理CSV文件的主函数
     
@@ -185,6 +205,8 @@ def process_csv(
         target_lang: 目标语言
         encoding: 文件编码
         header_row: 表头所在行号(1-based)
+        header_search_start: 表头搜索起始行(1-based)
+        data_search_start: 数据搜索起始行(1-based)
     """
     try:
         # 记录用户传入的参数
@@ -193,6 +215,8 @@ def process_csv(
         logger.info(f"处理列:{columns}")
         logger.info(f"编码:{encoding}")
         logger.info(f"表头行号:{header_row}")
+        logger.info(f"表头搜索起始行:{header_search_start}")
+        logger.info(f"数据搜索起始行:{data_search_start}")
         
         # 转换列字母为索引
         if isinstance(columns, str):
@@ -243,8 +267,10 @@ if __name__ == "__main__":
     #     input_file=file_path,
     #     output_file=output_path,
     #     columns=found_columns,  # 使用搜索到的列
-    #     start_row=1,
+    #     start_row=2,
     #     source_lang='auto',
     #     target_lang='zh-CN',
-    #     header_row=1
+    #     header_row=1,
+    #     header_search_start=1,
+    #     data_search_start=2
     # )