translator.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from __future__ import annotations
  2. import os
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import Iterable
  6. from huggingface_hub import snapshot_download
  7. from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  8. try:
  9. from openai import OpenAI
  10. except ImportError:
  11. OpenAI = None # 延迟到选择 OpenAI backend 时再报错
  12. NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
  13. DEFAULT_MODEL_CACHE = Path(".models") / "nllb-200"
  14. SOURCE_LANG_CODE = "zho_Hans"
  15. SUPPORTED_AI_LANGS = {
  16. "en": {"nllb": "eng_Latn", "label": "译文(EN)"},
  17. "jp": {"nllb": "jpn_Jpan", "label": "訳文(JP)"},
  18. "korean": {"nllb": "kor_Hang", "label": "번译(KOREAN)"},
  19. }
  20. class TranslationError(Exception):
  21. """Generic translation failure."""
  22. @dataclass
  23. class TranslatorConfig:
  24. backend: str = "local"
  25. openai_model: str = "gpt-4o-mini"
  26. openai_api_key: str | None = None
  27. model_cache_dir: Path = DEFAULT_MODEL_CACHE
  28. batch_size: int = 4
  29. max_length: int = 512
  30. class BaseTranslator:
  31. def __init__(self) -> None:
  32. self._cache: dict[tuple[str, str], str] = {}
  33. def translate(self, texts: Iterable[str], target_lang: str) -> list[str]:
  34. normalized_lang = target_lang.lower()
  35. missing: list[str] = []
  36. order: list[int] = []
  37. results: list[str] = []
  38. for idx, text in enumerate(texts):
  39. key = (text, normalized_lang)
  40. if key in self._cache:
  41. results.append(self._cache[key])
  42. else:
  43. order.append(idx)
  44. missing.append(text)
  45. results.append("") # placeholder
  46. if missing:
  47. new_values = self._translate_impl(missing, normalized_lang)
  48. if len(new_values) != len(missing):
  49. raise TranslationError("翻译服务返回的结果数量与输入不匹配。")
  50. for idx, translated in enumerate(new_values):
  51. pos = order[idx]
  52. text = missing[idx]
  53. key = (text, normalized_lang)
  54. self._cache[key] = translated
  55. results[pos] = translated
  56. return results
  57. def translate_text(self, text: str, target_lang: str) -> str:
  58. return self.translate([text], target_lang)[0]
  59. def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
  60. raise NotImplementedError
  61. class LocalNLLBTranslator(BaseTranslator):
  62. def __init__(self, config: TranslatorConfig) -> None:
  63. super().__init__()
  64. try:
  65. import torch
  66. except ImportError as exc:
  67. raise RuntimeError("需要安装 torch 才能加载本地 NLLB 模型。") from exc
  68. self.torch = torch
  69. self.config = config
  70. model_dir = self._ensure_model_files(config.model_cache_dir)
  71. self.tokenizer = AutoTokenizer.from_pretrained(
  72. model_dir,
  73. src_lang=SOURCE_LANG_CODE,
  74. use_fast=False,
  75. )
  76. self.tokenizer.src_lang = SOURCE_LANG_CODE
  77. self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
  78. self.device = self.torch.device(
  79. "cuda" if self.torch.cuda.is_available() else "cpu"
  80. )
  81. self.model.to(self.device)
  82. def _ensure_model_files(self, cache_dir: Path) -> Path:
  83. cache_dir = Path(cache_dir)
  84. cache_dir.mkdir(parents=True, exist_ok=True)
  85. config_file = cache_dir / "config.json"
  86. if not config_file.exists():
  87. snapshot_download(
  88. repo_id=NLLB_MODEL_NAME,
  89. local_dir=str(cache_dir),
  90. local_dir_use_symlinks=False,
  91. resume_download=True,
  92. )
  93. return cache_dir
  94. def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
  95. if target_lang not in SUPPORTED_AI_LANGS:
  96. raise TranslationError(f"未支持的目标语言: {target_lang}")
  97. target_code = SUPPORTED_AI_LANGS[target_lang]["nllb"]
  98. forced_bos = self.tokenizer.convert_tokens_to_ids(target_code)
  99. if forced_bos == self.tokenizer.unk_token_id:
  100. raise TranslationError(f"NLLB 模型不支持目标语言代码: {target_code}")
  101. outputs: list[str] = []
  102. batch_size = max(1, self.config.batch_size)
  103. for start in range(0, len(texts), batch_size):
  104. batch = texts[start : start + batch_size]
  105. encoded = self.tokenizer(
  106. batch,
  107. return_tensors="pt",
  108. padding=True,
  109. truncation=True,
  110. max_length=self.config.max_length,
  111. ).to(self.device)
  112. with self.torch.no_grad():
  113. generated = self.model.generate(
  114. **encoded,
  115. forced_bos_token_id=forced_bos,
  116. max_length=self.config.max_length,
  117. )
  118. decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
  119. outputs.extend(decoded)
  120. return outputs
  121. class OpenAITranslator(BaseTranslator):
  122. def __init__(self, config: TranslatorConfig) -> None:
  123. super().__init__()
  124. if OpenAI is None:
  125. raise RuntimeError("需要安装 openai 包才可以使用 OpenAI backend。")
  126. api_key = config.openai_api_key or os.getenv("OPENAI_API_KEY")
  127. if not api_key:
  128. raise RuntimeError("缺少 OPENAI_API_KEY,无法连接 OpenAI API。")
  129. self.client = OpenAI(api_key=api_key)
  130. self.model = config.openai_model
  131. def _translate_impl(self, texts: list[str], target_lang: str) -> list[str]:
  132. if target_lang not in SUPPORTED_AI_LANGS:
  133. raise TranslationError(f"未支持的目标语言: {target_lang}")
  134. language_name = SUPPORTED_AI_LANGS[target_lang]["label"]
  135. results: list[str] = []
  136. for text in texts:
  137. response = self.client.chat.completions.create(
  138. model=self.model,
  139. messages=[
  140. {
  141. "role": "system",
  142. "content": "You are a professional video game localization translator.",
  143. },
  144. {
  145. "role": "user",
  146. "content": (
  147. f"Translate the following Simplified Chinese text to {language_name}. "
  148. "Preserve the tone, keep placeholders or tags unchanged, and return only the translation.\n"
  149. f"Text:\n{text}"
  150. ),
  151. },
  152. ],
  153. temperature=0.2,
  154. )
  155. translated = response.choices[0].message.content.strip()
  156. results.append(translated)
  157. return results
  158. def build_translator(config: TranslatorConfig) -> BaseTranslator:
  159. backend = config.backend.lower()
  160. if backend == "local":
  161. return LocalNLLBTranslator(config)
  162. if backend == "openai":
  163. return OpenAITranslator(config)
  164. raise ValueError(f"未知的翻译 backend: {config.backend}")