Sfoglia il codice sorgente

feat: implement column translation with new column insertion and CSV processing

mrh (aider) 1 anno fa
parent
commit
ab7a4c599a
1 ha cambiato i file con 59 aggiunte e 30 eliminazioni
  1. 59 30
      mylib/new_col_translate.py

+ 59 - 30
mylib/new_col_translate.py

@@ -67,22 +67,26 @@ def read_csv_with_header(file_path: str, header_row: int = 1, encoding: str = No
         logger.error(f"读取CSV文件时出错: {e}")
         raise
 
-def extract_column_data(df: pd.DataFrame, column_identifier: Union[str, int], start_row: int = 2, header_row: int = 1) -> pd.Series:
-    """提取指定列的数据,默认从第3行开始
+def translate_column_data(df: pd.DataFrame, column_identifier: Union[str, int], 
+                         start_row: int = 1, end_row: int = None,
+                         source_lang: str = 'auto', target_lang: str = 'zh-CN') -> pd.DataFrame:
+    """翻译指定列的数据并在右侧插入翻译结果列
     
     Args:
         df: pandas DataFrame
-        column_identifier: 要提取的列名或列号(从0开始),也可以是列字母(如 'A', 'B')
-        start_row: 开始提取的行号,默认为2(第3行)
-        header_row: 标题行号,默认为1(第2行)
+        column_identifier: 要翻译的列名或列号(从0开始),也可以是列字母(如 'A', 'B')
+        start_row: 开始翻译的行号,默认为1(第2行)
+        end_row: 结束翻译的行号,默认为None(到最后一行)
+        source_lang: 源语言代码,默认为'auto'
+        target_lang: 目标语言代码,默认为'zh-CN'
     
     Returns:
-        包含指定列数据的Series
+        包含翻译结果的DataFrame
     """
     try:
         if df.empty:
             logger.error("DataFrame为空")
-            return pd.Series()
+            return df
             
         # 处理列号或列名或列字母
         if isinstance(column_identifier, str) and column_identifier.isalpha():
@@ -98,43 +102,68 @@ def extract_column_data(df: pd.DataFrame, column_identifier: Union[str, int], st
             logger.error(f"列名 {column_identifier} 不存在")
             raise ValueError(f"列名 {column_identifier} 不存在")
             
-        # 确保开始行在有效范围内
-        if start_row >= len(df) or start_row < 0:
-            logger.error(f"开始行 {start_row} 超出范围")
-            raise ValueError(f"开始行 {start_row} 超出范围")
+        # 处理行范围
+        if end_row is None:
+            end_row = len(df)
+        if start_row < 0 or start_row >= len(df) or end_row < 0 or end_row > len(df):
+            logger.error(f"行范围 {start_row}-{end_row} 超出范围")
+            raise ValueError(f"行范围 {start_row}-{end_row} 超出范围")
             
-        # 提取指定列的数据
-        column_data = df.iloc[start_row:][column_identifier]
-        logger.info(f"成功提取列 {column_identifier} 数据,从第{start_row}行开始,共{len(column_data)}条数据")
-        return column_data
+        # 提取要翻译的数据
+        texts_to_translate = df.iloc[start_row:end_row][column_identifier].tolist()
+        logger.info(f"准备翻译 {len(texts_to_translate)} 条数据,从第{start_row}行到第{end_row}行")
+        
+        # 初始化翻译器
+        translator = OpenAITranslator(lang_out=target_lang, lang_in=source_lang)
+        
+        # 执行翻译
+        translated_texts = translator._batch_translate(texts_to_translate)
+        
+        # 在右侧插入新列
+        new_column_name = f"{column_identifier}_translated"
+        df.insert(df.columns.get_loc(column_identifier) + 1, new_column_name, "")
+        
+        # 填充翻译结果
+        df.loc[start_row:end_row-1, new_column_name] = translated_texts
+        
+        logger.info(f"翻译完成,已插入新列 {new_column_name}")
+        return df
         
     except Exception as e:
-        logger.error(f"提取列数据时出错: {e}")
+        logger.error(f"翻译列数据时出错: {e}")
         raise
 
-def test_column_extraction(input_file: str):
-    """测试列提取功能
+def process_csv(input_file: str, output_file: str, column_identifier: Union[str, int],
+               start_row: int = 1, end_row: int = None,
+               source_lang: str = 'auto', target_lang: str = 'zh-CN'):
+    """处理CSV文件并保存翻译结果
     
     Args:
         input_file: 输入CSV文件路径
+        output_file: 输出CSV文件路径
+        column_identifier: 要翻译的列名或列号(从0开始),也可以是列字母(如 'A', 'B')
+        start_row: 开始翻译的行号,默认为1(第2行)
+        end_row: 结束翻译的行号,默认为None(到最后一行)
+        source_lang: 源语言代码,默认为'auto'
+        target_lang: 目标语言代码,默认为'zh-CN'
     """
     try:
-        if not os.path.exists(input_file):
-            logger.error(f"文件不存在: {input_file}")
-            raise FileNotFoundError(f"文件不存在: {input_file}")
-        
         # 读取CSV文件
-        df = read_csv_with_header(input_file, header_row=1)
+        df = read_csv_with_header(input_file)
         
-        # 提取第二列的数据,从第三行开始
-        column_data = extract_column_data(df, column_identifier=1, start_row=2, header_row=1)
+        # 翻译指定列
+        df = translate_column_data(df, column_identifier, start_row, end_row, source_lang, target_lang)
         
-        # 打印提取的数据
-        print("提取的列数据:")
-        print(column_data)
+        # 保存结果
+        df.to_csv(output_file, index=False, encoding='utf-8-sig')
+        logger.info(f"翻译结果已保存到 {output_file}")
         
     except Exception as e:
-        logger.error(f"测试列提取时出错: {e}")
+        logger.error(f"处理CSV文件时出错: {e}")
+        raise
 
 if __name__ == '__main__':
-    input_file = Path('/home/mrh/code/excel_tool/temp/测试.csv')
+    # 示例用法
+    input_file = Path('/path/to/input.csv')
+    output_file = Path('/path/to/output.csv')
+    process_csv(input_file, output_file, column_identifier='B', start_row=1, end_row=10)