|
|
@@ -0,0 +1,194 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import os
|
|
|
+from dataclasses import dataclass
|
|
|
+from pathlib import Path
|
|
|
+from typing import Iterable
|
|
|
+
|
|
|
+from huggingface_hub import snapshot_download
|
|
|
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
+
|
|
|
# The openai package is optional: when it is missing, OpenAI is bound to None
# so that the error is deferred until the OpenAI backend is actually selected.
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None
|
|
|
+
|
|
|
+
|
|
|
# Hugging Face repository id of the NLLB checkpoint.
NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
# Default directory (relative path) for the local model files.
DEFAULT_MODEL_CACHE = Path(".models") / "nllb-200"
# NLLB language code used as the translation source language.
SOURCE_LANG_CODE = "zho_Hans"
# Supported target languages, keyed by lower-case language id.
# "nllb" is the NLLB target language code; "label" is a display label.
# NOTE(review): the "korean" label mixes Hangul and a Chinese character
# ("번译"); "번역" may have been intended -- confirm before changing, since the
# label is runtime text.
SUPPORTED_AI_LANGS = {
    "en": {"nllb": "eng_Latn", "label": "译文(EN)"},
    "jp": {"nllb": "jpn_Jpan", "label": "訳文(JP)"},
    "korean": {"nllb": "kor_Hang", "label": "번译(KOREAN)"},
}
|
|
|
+
|
|
|
+
|
|
|
class TranslationError(Exception):
    """Generic translation failure.

    Raised by the translators in this module for an unsupported target
    language or a backend result that does not match the request.
    """
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class TranslatorConfig:
    """Settings consumed by the translator classes and ``build_translator``."""

    # Backend selector; "local" and "openai" are the values build_translator
    # recognizes (compared case-insensitively).
    backend: str = "local"
    # Model name passed to the OpenAI chat completions endpoint.
    openai_model: str = "gpt-4o-mini"
    # Explicit API key; when empty the OPENAI_API_KEY environment variable is
    # used by the OpenAI translator.
    openai_api_key: str | None = None
    # Directory where the local NLLB model files live / are downloaded to.
    model_cache_dir: Path = DEFAULT_MODEL_CACHE
    # Number of texts per generation batch for the local backend.
    batch_size: int = 4
    # Token limit used for both input truncation and generation length.
    max_length: int = 512
|
|
|
+
|
|
|
+
|
|
|
class BaseTranslator:
    """Base class adding an in-memory translation cache to a backend.

    Subclasses implement ``_translate_impl``; callers use ``translate`` or
    ``translate_text``. Results are cached per ``(text, lower-cased target
    language)`` for the lifetime of the instance.
    """

    def __init__(self) -> None:
        # Maps (source text, normalized target language) -> translated text.
        self._cache: dict[tuple[str, str], str] = {}

    def translate(self, texts: Iterable[str], target_lang: str) -> list[str]:
        """Translate *texts* into *target_lang*, reusing cached results.

        Args:
            texts: Source strings; consumed exactly once, so a generator is
                acceptable.
            target_lang: Target language id; lower-cased before use.

        Returns:
            One translation per input text, in input order.

        Raises:
            TranslationError: If the backend returns a different number of
                results than the number of texts it was given.
        """
        normalized_lang = target_lang.lower()
        items = list(texts)

        # Only texts not yet cached go to the backend. dict.fromkeys removes
        # duplicates while keeping first-seen order, so a text repeated within
        # one call is translated once instead of once per occurrence.
        pending = list(
            dict.fromkeys(
                text for text in items if (text, normalized_lang) not in self._cache
            )
        )

        if pending:
            new_values = self._translate_impl(pending, normalized_lang)
            if len(new_values) != len(pending):
                raise TranslationError("翻译服务返回的结果数量与输入不匹配。")
            for text, translated in zip(pending, new_values):
                self._cache[(text, normalized_lang)] = translated

        # Every input is cached at this point.
        return [self._cache[(text, normalized_lang)] for text in items]

    def translate_text(self, text: str, target_lang: str) -> str:
        """Translate a single string; convenience wrapper around ``translate``."""
        return self.translate([text], target_lang)[0]

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        """Translate uncached *texts*; must return one result per input.

        *target_lang* is already lower-cased. Subclasses must override this.
        """
        raise NotImplementedError
|
|
|
+
|
|
|
+
|
|
|
class LocalNLLBTranslator(BaseTranslator):
    """Translator backed by a locally stored NLLB checkpoint.

    The model files are looked up in (or downloaded into)
    ``config.model_cache_dir`` and loaded through ``transformers``; the model
    is placed on CUDA when available, otherwise on the CPU.
    """

    def __init__(self, config: TranslatorConfig) -> None:
        """Load tokenizer and model.

        Raises:
            RuntimeError: If ``torch`` is not installed.
        """
        super().__init__()
        # torch is imported here rather than at module level, so its absence
        # only surfaces when the local backend is constructed.
        try:
            import torch
        except ImportError as exc:
            raise RuntimeError("需要安装 torch 才能加载本地 NLLB 模型。") from exc

        self.torch = torch
        self.config = config
        model_dir = self._ensure_model_files(config.model_cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            src_lang=SOURCE_LANG_CODE,
            use_fast=False,
        )
        self.tokenizer.src_lang = SOURCE_LANG_CODE
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.device = self.torch.device(
            "cuda" if self.torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)

    def _ensure_model_files(self, cache_dir: Path) -> Path:
        """Create *cache_dir* and download the model into it when needed.

        Returns the directory as a ``Path``.
        """
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        config_file = cache_dir / "config.json"
        # NOTE(review): only config.json is probed, so a download interrupted
        # after config.json was written is not retried here.
        if not config_file.exists():
            # NOTE(review): newer huggingface_hub releases reportedly
            # deprecate local_dir_use_symlinks / resume_download -- confirm
            # against the pinned version.
            snapshot_download(
                repo_id=NLLB_MODEL_NAME,
                local_dir=str(cache_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
            )
        return cache_dir

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        """Translate *texts* in batches of ``config.batch_size``.

        Raises:
            TranslationError: If *target_lang* is not in SUPPORTED_AI_LANGS or
                the tokenizer maps the target language code to its unknown
                token.
        """
        if target_lang not in SUPPORTED_AI_LANGS:
            raise TranslationError(f"未支持的目标语言: {target_lang}")

        target_code = SUPPORTED_AI_LANGS[target_lang]["nllb"]
        # The target language token id is forced as the first generated token.
        forced_bos = self.tokenizer.convert_tokens_to_ids(target_code)
        if forced_bos == self.tokenizer.unk_token_id:
            raise TranslationError(f"NLLB 模型不支持目标语言代码: {target_code}")

        outputs: list[str] = []
        # Guard against a zero or negative configured batch size.
        batch_size = max(1, self.config.batch_size)
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            # Inputs longer than max_length tokens are truncated.
            encoded = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.config.max_length,
            ).to(self.device)
            with self.torch.no_grad():
                generated = self.model.generate(
                    **encoded,
                    forced_bos_token_id=forced_bos,
                    max_length=self.config.max_length,
                )
            decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
            outputs.extend(decoded)
        return outputs
|
|
|
+
|
|
|
+
|
|
|
class OpenAITranslator(BaseTranslator):
    """Translator that sends each text to the OpenAI chat completions API."""

    # Plain language names for the prompt. The "label" values in
    # SUPPORTED_AI_LANGS (e.g. "译文(EN)") are not language names, so they are
    # only used as a fallback for ids missing from this table.
    _PROMPT_LANGUAGE_NAMES = {
        "en": "English",
        "jp": "Japanese",
        "korean": "Korean",
    }

    def __init__(self, config: TranslatorConfig) -> None:
        """Create the API client.

        The key comes from ``config.openai_api_key`` or, when that is empty,
        the ``OPENAI_API_KEY`` environment variable.

        Raises:
            RuntimeError: If the ``openai`` package is not installed or no API
                key is available.
        """
        super().__init__()
        if OpenAI is None:
            raise RuntimeError("需要安装 openai 包才可以使用 OpenAI backend。")
        api_key = config.openai_api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("缺少 OPENAI_API_KEY,无法连接 OpenAI API。")
        self.client = OpenAI(api_key=api_key)
        self.model = config.openai_model

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        """Translate *texts* one request at a time.

        Empty or whitespace-only texts are returned unchanged without calling
        the API.

        Raises:
            TranslationError: If *target_lang* is not in SUPPORTED_AI_LANGS or
                the API response carries no text content.
        """
        if target_lang not in SUPPORTED_AI_LANGS:
            raise TranslationError(f"未支持的目标语言: {target_lang}")

        language_name = self._PROMPT_LANGUAGE_NAMES.get(
            target_lang, SUPPORTED_AI_LANGS[target_lang]["label"]
        )
        results: list[str] = []
        for text in texts:
            if not text.strip():
                # Nothing to translate; avoid a pointless request.
                results.append(text)
                continue
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a professional video game localization translator.",
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Translate the following Simplified Chinese text to {language_name}. "
                            "Preserve the tone, keep placeholders or tags unchanged, and return only the translation.\n"
                            f"Text:\n{text}"
                        ),
                    },
                ],
                temperature=0.2,
            )
            # The response may have no choices or a null message content;
            # surface that as a translation failure instead of crashing with
            # IndexError / AttributeError.
            content = response.choices[0].message.content if response.choices else None
            if content is None:
                raise TranslationError("OpenAI 返回了空的翻译结果。")
            results.append(content.strip())
        return results
|
|
|
+
|
|
|
+
|
|
|
def build_translator(config: TranslatorConfig) -> BaseTranslator:
    """Construct the translator selected by ``config.backend``.

    The backend name is matched case-insensitively and with surrounding
    whitespace ignored.

    Args:
        config: Settings; ``backend`` must be "local" or "openai".

    Returns:
        A ``LocalNLLBTranslator`` or ``OpenAITranslator`` built from *config*.

    Raises:
        ValueError: If the backend name is not recognized.
    """
    backend = config.backend.strip().lower()
    if backend == "local":
        return LocalNLLBTranslator(config)
    if backend == "openai":
        return OpenAITranslator(config)
    raise ValueError(f"未知的翻译 backend: {config.backend}")
|