from __future__ import annotations

# Translation backends for Simplified Chinese source text.
#
# Two interchangeable backends are provided behind ``BaseTranslator``:
#   * ``LocalNLLBTranslator`` - runs the NLLB-200 distilled model locally.
#   * ``OpenAITranslator``    - calls the OpenAI chat completions API.
# Use ``build_translator(config)`` to pick one from a ``TranslatorConfig``.

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

# Every heavy third-party dependency is optional at import time so that one
# backend can be used without installing the other backend's stack. A missing
# package is reported only when the backend that needs it is constructed.
try:
    from huggingface_hub import snapshot_download
except ImportError:
    snapshot_download = None

try:
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except ImportError:
    AutoModelForSeq2SeqLM = None
    AutoTokenizer = None

try:
    from openai import OpenAI
except ImportError:
    OpenAI = None  # defer the error until the OpenAI backend is selected

NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
DEFAULT_MODEL_CACHE = Path(".models") / "nllb-200"
SOURCE_LANG_CODE = "zho_Hans"

# Maps the (lower-cased) target language key to its NLLB language code and
# to the label used for that language.
SUPPORTED_AI_LANGS = {
    "en": {"nllb": "eng_Latn", "label": "译文(EN)"},
    "jp": {"nllb": "jpn_Jpan", "label": "訳文(JP)"},
    "korean": {"nllb": "kor_Hang", "label": "번译(KOREAN)"},
}


class TranslationError(Exception):
    """Generic translation failure."""


@dataclass
class TranslatorConfig:
    """Settings shared by all translation backends."""

    # Backend selector understood by ``build_translator``: "local" or "openai".
    backend: str = "local"
    # Chat model name sent to the OpenAI API.
    openai_model: str = "gpt-4o-mini"
    # Explicit API key; when empty, the OPENAI_API_KEY env var is used instead.
    openai_api_key: str | None = None
    # Directory holding (or receiving) the downloaded NLLB model files.
    model_cache_dir: Path = DEFAULT_MODEL_CACHE
    # Number of texts per generate() call in the local backend.
    batch_size: int = 4
    # Token limit used both for input truncation and for generation.
    max_length: int = 512


def _lookup_language(target_lang: str) -> dict[str, str]:
    """Return the ``SUPPORTED_AI_LANGS`` entry for *target_lang*.

    Raises:
        TranslationError: if *target_lang* is not a supported key.
    """
    try:
        return SUPPORTED_AI_LANGS[target_lang]
    except KeyError:
        raise TranslationError(f"未支持的目标语言: {target_lang}") from None


class BaseTranslator:
    """Common caching front-end; subclasses implement ``_translate_impl``."""

    def __init__(self) -> None:
        # (source text, lower-cased target language) -> translation.
        # The cache lives as long as the instance and is never evicted.
        self._cache: dict[tuple[str, str], str] = {}

    def translate(self, texts: Iterable[str], target_lang: str) -> list[str]:
        """Translate *texts* into *target_lang*, preserving input order.

        Cached translations are reused. Each distinct uncached text is sent
        to the backend exactly once per call, even if it appears several
        times in *texts*.

        Args:
            texts: Source strings; any iterable (consumed once).
            target_lang: Target language key, matched case-insensitively.

        Returns:
            One translation per input text, in the same order.

        Raises:
            TranslationError: if the backend returns a different number of
                results than it was given.
        """
        normalized_lang = target_lang.lower()
        results: list[str] = []
        # Uncached text -> every position in ``results`` waiting for it.
        # dict ordering keeps the backend input in first-occurrence order.
        pending: dict[str, list[int]] = {}

        for idx, text in enumerate(texts):
            key = (text, normalized_lang)
            if key in self._cache:
                results.append(self._cache[key])
            else:
                pending.setdefault(text, []).append(idx)
                results.append("")  # placeholder, filled in below

        if pending:
            missing = list(pending)
            new_values = self._translate_impl(missing, normalized_lang)
            if len(new_values) != len(missing):
                raise TranslationError("翻译服务返回的结果数量与输入不匹配。")
            for text, translated in zip(missing, new_values):
                self._cache[(text, normalized_lang)] = translated
                for pos in pending[text]:
                    results[pos] = translated

        return results

    def translate_text(self, text: str, target_lang: str) -> str:
        """Translate a single string; convenience wrapper over ``translate``."""
        return self.translate([text], target_lang)[0]

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        """Translate *texts* (already lower-cased *target_lang*); one result per text."""
        raise NotImplementedError


class LocalNLLBTranslator(BaseTranslator):
    """Backend that runs the NLLB model locally through transformers/torch."""

    def __init__(self, config: TranslatorConfig) -> None:
        """Load tokenizer and model, downloading the model files if needed.

        Raises:
            RuntimeError: if torch, transformers or huggingface_hub is missing.
        """
        super().__init__()
        try:
            import torch
        except ImportError as exc:
            raise RuntimeError("需要安装 torch 才能加载本地 NLLB 模型。") from exc
        if (
            snapshot_download is None
            or AutoTokenizer is None
            or AutoModelForSeq2SeqLM is None
        ):
            raise RuntimeError(
                "需要安装 transformers 和 huggingface_hub 才能加载本地 NLLB 模型。"
            )

        self.torch = torch
        self.config = config
        model_dir = self._ensure_model_files(config.model_cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            src_lang=SOURCE_LANG_CODE,
            use_fast=False,
        )
        self.tokenizer.src_lang = SOURCE_LANG_CODE
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.device = self.torch.device(
            "cuda" if self.torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)

    def _ensure_model_files(self, cache_dir: Path) -> Path:
        """Create *cache_dir* and download the model into it when absent.

        The presence of ``config.json`` is the only completeness check.
        NOTE(review): an interrupted download that already wrote config.json
        would be treated as complete - confirm whether that matters here.
        """
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        config_file = cache_dir / "config.json"
        if not config_file.exists():
            # NOTE(review): these two flags may be deprecated in newer
            # huggingface_hub releases - verify against the pinned version.
            snapshot_download(
                repo_id=NLLB_MODEL_NAME,
                local_dir=str(cache_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
            )
        return cache_dir

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        """Translate *texts* in batches of ``config.batch_size``.

        Raises:
            TranslationError: if *target_lang* is unsupported or the
                tokenizer does not know the corresponding NLLB code.
        """
        target_code = _lookup_language(target_lang)["nllb"]
        forced_bos = self.tokenizer.convert_tokens_to_ids(target_code)
        # An unknown token maps to unk_token_id, i.e. the code is not in the vocab.
        if forced_bos == self.tokenizer.unk_token_id:
            raise TranslationError(f"NLLB 模型不支持目标语言代码: {target_code}")

        outputs: list[str] = []
        batch_size = max(1, self.config.batch_size)  # guard against <= 0
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            encoded = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.config.max_length,
            ).to(self.device)
            with self.torch.no_grad():
                generated = self.model.generate(
                    **encoded,
                    forced_bos_token_id=forced_bos,
                    max_length=self.config.max_length,
                )
            decoded = self.tokenizer.batch_decode(
                generated, skip_special_tokens=True
            )
            outputs.extend(decoded)
        return outputs


class OpenAITranslator(BaseTranslator):
    """Backend that translates one text per OpenAI chat completion request."""

    def __init__(self, config: TranslatorConfig) -> None:
        """Create the API client.

        Raises:
            RuntimeError: if the openai package is missing or no API key is
                available from the config or the OPENAI_API_KEY env var.
        """
        super().__init__()
        if OpenAI is None:
            raise RuntimeError("需要安装 openai 包才可以使用 OpenAI backend。")
        api_key = config.openai_api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("缺少 OPENAI_API_KEY,无法连接 OpenAI API。")
        self.client = OpenAI(api_key=api_key)
        self.model = config.openai_model

    @staticmethod
    def _extract_translation(response) -> str:
        """Return the stripped text of the first choice in *response*.

        Raises:
            TranslationError: if the response has no choices or the message
                content is ``None``.
        """
        choices = response.choices
        if not choices:
            raise TranslationError("OpenAI 未返回任何翻译结果。")
        content = choices[0].message.content
        if content is None:
            raise TranslationError("OpenAI 返回了空的翻译内容。")
        return content.strip()

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        """Translate each text with a separate chat completion call.

        Raises:
            TranslationError: if *target_lang* is unsupported or a response
                carries no usable content.
        """
        # NOTE(review): the prompt interpolates the label (e.g. "译文(EN)")
        # rather than a plain language name - confirm this is intended.
        language_name = _lookup_language(target_lang)["label"]
        results: list[str] = []
        for text in texts:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a professional video game localization translator.",
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Translate the following Simplified Chinese text to {language_name}. "
                            "Preserve the tone, keep placeholders or tags unchanged, and return only the translation.\n"
                            f"Text:\n{text}"
                        ),
                    },
                ],
                temperature=0.2,
            )
            results.append(self._extract_translation(response))
        return results


def build_translator(config: TranslatorConfig) -> BaseTranslator:
    """Instantiate the backend named by ``config.backend`` (case-insensitive).

    Raises:
        ValueError: if the backend name is neither "local" nor "openai".
    """
    backend = config.backend.lower()
    if backend == "local":
        return LocalNLLBTranslator(config)
    if backend == "openai":
        return OpenAITranslator(config)
    raise ValueError(f"未知的翻译 backend: {config.backend}")