from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

from huggingface_hub import snapshot_download
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

try:
    from openai import OpenAI
except ImportError:
    OpenAI = None  # defer the import error until the OpenAI backend is actually selected
NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
DEFAULT_MODEL_CACHE = Path(".models") / "nllb-200"
SOURCE_LANG_CODE = "zho_Hans"

# Maps app-level language keys to the NLLB language code, an English language
# name (used when prompting the OpenAI backend), and a localized display label.
SUPPORTED_AI_LANGS = {
    "en": {"nllb": "eng_Latn", "name": "English", "label": "译文(EN)"},
    "jp": {"nllb": "jpn_Jpan", "name": "Japanese", "label": "訳文(JP)"},
    "korean": {"nllb": "kor_Hang", "name": "Korean", "label": "번역(KOREAN)"},
}
class TranslationError(Exception):
    """Generic translation failure."""
@dataclass
class TranslatorConfig:
    backend: str = "local"  # "local" (offline NLLB) or "openai"
    openai_model: str = "gpt-4o-mini"
    openai_api_key: str | None = None  # falls back to the OPENAI_API_KEY env var
    model_cache_dir: Path = DEFAULT_MODEL_CACHE
    batch_size: int = 4
    max_length: int = 512
class BaseTranslator:
    """Caching front-end; subclasses implement _translate_impl."""

    def __init__(self) -> None:
        self._cache: dict[tuple[str, str], str] = {}

    def translate(self, texts: Iterable[str], target_lang: str) -> list[str]:
        normalized_lang = target_lang.lower()
        missing: list[str] = []
        order: list[int] = []
        results: list[str] = []
        for idx, text in enumerate(texts):
            key = (text, normalized_lang)
            if key in self._cache:
                results.append(self._cache[key])
            else:
                order.append(idx)
                missing.append(text)
                results.append("")  # placeholder, filled in below
        if missing:
            new_values = self._translate_impl(missing, normalized_lang)
            if len(new_values) != len(missing):
                raise TranslationError(
                    "The translation backend returned a result count that does not match the input."
                )
            for idx, translated in enumerate(new_values):
                pos = order[idx]
                key = (missing[idx], normalized_lang)
                self._cache[key] = translated
                results[pos] = translated
        return results

    def translate_text(self, text: str, target_lang: str) -> str:
        return self.translate([text], target_lang)[0]

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        raise NotImplementedError
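
# A minimal sketch of the subclassing contract (illustration only; the
# _EchoTranslator name is hypothetical, not part of this module's API):
# _translate_impl receives only the cache misses, in input order, and must
# return exactly one string per input. BaseTranslator.translate handles
# caching and reassembling results into the original order.
class _EchoTranslator(BaseTranslator):
    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        return [f"[{target_lang}] {text}" for text in texts]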

class LocalNLLBTranslator(BaseTranslator):
    def __init__(self, config: TranslatorConfig) -> None:
        super().__init__()
        try:
            import torch
        except ImportError as exc:
            raise RuntimeError("torch is required to load the local NLLB model.") from exc
        self.torch = torch
        self.config = config
        model_dir = self._ensure_model_files(config.model_cache_dir)
        # The slow (sentencepiece) tokenizer is used; src_lang tells NLLB that
        # the input is Simplified Chinese.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            src_lang=SOURCE_LANG_CODE,
            use_fast=False,
        )
        self.tokenizer.src_lang = SOURCE_LANG_CODE
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.device = self.torch.device(
            "cuda" if self.torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)

    def _ensure_model_files(self, cache_dir: Path) -> Path:
        """Download the checkpoint into cache_dir on first use."""
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        config_file = cache_dir / "config.json"
        if not config_file.exists():
            snapshot_download(
                repo_id=NLLB_MODEL_NAME,
                local_dir=str(cache_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
            )
        return cache_dir
    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        if target_lang not in SUPPORTED_AI_LANGS:
            raise TranslationError(f"Unsupported target language: {target_lang}")
        target_code = SUPPORTED_AI_LANGS[target_lang]["nllb"]
        forced_bos = self.tokenizer.convert_tokens_to_ids(target_code)
        if forced_bos == self.tokenizer.unk_token_id:
            raise TranslationError(f"The NLLB model does not support target language code: {target_code}")
        outputs: list[str] = []
        batch_size = max(1, self.config.batch_size)
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            encoded = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.config.max_length,
            ).to(self.device)
            with self.torch.no_grad():
                # Pinning the first generated token to the target-language tag
                # is how NLLB selects the output language.
                generated = self.model.generate(
                    **encoded,
                    forced_bos_token_id=forced_bos,
                    max_length=self.config.max_length,
                )
            decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
            outputs.extend(decoded)
        return outputs
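
    # A rough sizing note (assumption, not a measured benchmark): with the
    # distilled 600M checkpoint, batch_size=4 and max_length=512 are modest
    # enough for a single consumer GPU, while on CPU a small batch_size keeps
    # per-call latency bounded. Strings translated in earlier calls are served
    # from BaseTranslator's cache and never reach this method.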

class OpenAITranslator(BaseTranslator):
    def __init__(self, config: TranslatorConfig) -> None:
        super().__init__()
        if OpenAI is None:
            raise RuntimeError("The openai package is required to use the OpenAI backend.")
        api_key = config.openai_api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is missing; cannot connect to the OpenAI API.")
        self.client = OpenAI(api_key=api_key)
        self.model = config.openai_model

    def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
        if target_lang not in SUPPORTED_AI_LANGS:
            raise TranslationError(f"Unsupported target language: {target_lang}")
        # Use the English language name here; the localized display label
        # (e.g. "译文(EN)") would be meaningless as a prompt target.
        language_name = SUPPORTED_AI_LANGS[target_lang]["name"]
        results: list[str] = []
        for text in texts:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a professional video game localization translator.",
                    },
                    {
                        "role": "user",
                        "content": (
                            f"Translate the following Simplified Chinese text to {language_name}. "
                            "Preserve the tone, keep placeholders or tags unchanged, and return only the translation.\n"
                            f"Text:\n{text}"
                        ),
                    },
                ],
                temperature=0.2,
            )
            translated = (response.choices[0].message.content or "").strip()
            results.append(translated)
        return results
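
    # Design note: requests are issued one string at a time for simplicity and
    # to keep the "return only the translation" contract unambiguous; packing
    # several strings into one prompt would cut round-trips at the cost of
    # having to parse a multi-part response.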

def build_translator(config: TranslatorConfig) -> BaseTranslator:
    backend = config.backend.lower()
    if backend == "local":
        return LocalNLLBTranslator(config)
    if backend == "openai":
        return OpenAITranslator(config)
    raise ValueError(f"Unknown translation backend: {config.backend}")
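
# Illustrative usage sketch, assuming the default local backend (the first
# run downloads the NLLB checkpoint, a couple of GB, into .models/nllb-200):
if __name__ == "__main__":
    translator = build_translator(TranslatorConfig(backend="local"))
    lines = ["开始游戏", "读取存档"]
    for src, dst in zip(lines, translator.translate(lines, "en")):
        print(f"{src} -> {dst}")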