// server/utils/textRecognizer.js
import { Tensor } from 'onnxruntime-node';
import sharp from 'sharp';
import fse from 'fs-extra';
import * as path from 'path';

class TextRecognizer {
  constructor() {
    this.recSession = null;
    this.config = null;
    this.characterSet = [];
    this.debugDir = path.join(process.cwd(), 'temp', 'debug');
    this.preprocessedDir = path.join(process.cwd(), 'temp', 'preprocessed');
    this.logger = {
      info: (msg, ...args) => console.log(`🔤 [recognition] ${msg}`, ...args),
      error: (msg, ...args) => console.error(`❌ [recognition] ${msg}`, ...args),
      debug: (msg, ...args) => console.log(`🐛 [recognition] ${msg}`, ...args),
      warn: (msg, ...args) => console.warn(`⚠️ [recognition] ${msg}`, ...args)
    };
    // Make sure the output directories exist
    fse.ensureDirSync(this.debugDir);
    fse.ensureDirSync(this.preprocessedDir);
  }

  initialize(recSession, config) {
    this.recSession = recSession;
    this.config = config;
    this.logger.info('Text recognizer initialized');
  }

  async loadCharacterSet(keysPath) {
    try {
      const keysContent = await fse.readFile(keysPath, 'utf8');
      this.characterSet = [];
      const lines = keysContent.split('\n');
      // Use the supplied character-set (keys) file
      const uniqueChars = new Set();
      for (const line of lines) {
        const trimmed = line.trim();
        // Skip empty lines and comment lines
        if (trimmed && !trimmed.startsWith('#')) {
          // Treat each line as a single character entry
          uniqueChars.add(trimmed);
        }
      }
      this.characterSet = Array.from(uniqueChars);
      if (this.characterSet.length === 0) {
        throw new Error('Character set file is empty or malformed');
      }
      this.logger.info(`Character set loaded: ${this.characterSet.length} characters`);
      // Record character-set statistics
      const charTypes = {
        chinese: 0,
        english: 0,
        digit: 0,
        punctuation: 0,
        other: 0
      };
      this.characterSet.forEach(char => {
        if (/[\u4e00-\u9fff]/.test(char)) {
          charTypes.chinese++;
        } else if (/[a-zA-Z]/.test(char)) {
          charTypes.english++;
        } else if (/[0-9]/.test(char)) {
          charTypes.digit++;
        } else if (/[,。!?;:""()【】《》…—·]/.test(char)) {
          charTypes.punctuation++;
        } else {
          charTypes.other++;
        }
      });
      this.logger.debug(`Character set breakdown: Chinese ${charTypes.chinese}, English ${charTypes.english}, digits ${charTypes.digit}, punctuation ${charTypes.punctuation}, other ${charTypes.other}`);
      this.logger.debug(`First 20 characters: ${this.characterSet.slice(0, 20).join('')}`);
    } catch (error) {
      this.logger.error('Failed to load character set', error.message);
      // Only the supplied character set is used; rethrow on failure
      throw new Error(`Character set loading failed: ${error.message}`);
    }
  }
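
  // Expected keys-file layout (illustrative sketch only; this mirrors PaddleOCR-style
  // dictionary files with one character per line and '#' lines treated as comments,
  // but the actual file shipped with the model may differ):
  //
  //   # example keys file
  //   的
  //   一
  //   0
  //   a
  //   %
  //
  // Each non-empty, non-comment line becomes one entry in this.characterSet. The CTC
  // blank is NOT listed in the file: index 0 of the model output is reserved for it
  // (see ctcDecode below), so the model's vocabulary size is characterSet.length + 1.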

  getCharacterSetSize() {
    return this.characterSet.length;
  }

  async recognizeText(textRegionBuffer, regionIndex = 0) {
    const startTime = Date.now();
    this.logger.info(`Starting text recognition - region ${regionIndex}`);
    try {
      const inputTensor = await this.prepareRecognitionInput(textRegionBuffer, regionIndex);
      const outputs = await this.recSession.run({ [this.recSession.inputNames[0]]: inputTensor });
      const result = this.postprocessRecognition(outputs);
      const processingTime = Date.now() - startTime;
      this.logger.info(`Recognition complete - region ${regionIndex}: "${result.text}", confidence: ${result.confidence.toFixed(4)}, elapsed: ${processingTime}ms`);
      return result;
    } catch (error) {
      this.logger.error(`Text recognition failed - region ${regionIndex}`, error);
      return { text: '', confidence: 0 };
    }
  }
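
  // Illustrative caller sketch (an assumption, not part of this module): an upstream
  // detection stage would typically crop each detected box out of the page image with
  // sharp and pass the cropped buffer here, e.g.
  //
  //   const box = { left: 120, top: 340, width: 260, height: 42 }; // hypothetical detector output
  //   const regionBuffer = await sharp(pageImageBuffer).extract(box).png().toBuffer();
  //   const { text, confidence } = await recognizer.recognizeText(regionBuffer, 0);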

  async prepareRecognitionInput(textRegionBuffer, regionIndex = 0) {
    this.logger.debug(`Preparing recognition input - region ${regionIndex}`);
    const targetHeight = 48;
    const targetWidth = 320; // original target width
    const finalWidth = targetWidth + 20; // final width (10 px of padding on each side)
    const timestamp = Date.now();
    try {
      const metadata = await sharp(textRegionBuffer).metadata();
      this.logger.debug(`Original region ${regionIndex}: ${metadata.width}x${metadata.height}`);
      // Save the original cropped region image
      const originalPath = path.join(this.preprocessedDir, `region-${regionIndex}-original-${timestamp}.png`);
      await fse.writeFile(originalPath, textRegionBuffer);
      this.logger.debug(`Saved original region image: ${originalPath}`);
      // Image analysis
      const stats = await sharp(textRegionBuffer).grayscale().stats();
      const meanBrightness = stats.channels[0].mean;
      const stdDev = stats.channels[0].stdev;
      this.logger.debug(`Image stats - region ${regionIndex}: brightness=${meanBrightness.toFixed(1)}, contrast=${stdDev.toFixed(1)}`);
      // Adaptive preprocessing
      let processedBuffer = await this.applySmartPreprocessing(textRegionBuffer, meanBrightness, stdDev, regionIndex);
      // Save the preprocessed image (after grayscale + contrast adjustment)
      const processedPath = path.join(this.preprocessedDir, `region-${regionIndex}-processed-${timestamp}.png`);
      await fse.writeFile(processedPath, processedBuffer);
      this.logger.debug(`Saved preprocessed image: ${processedPath}`);
      // Aspect-ratio-preserving resize, plus 10 px of blank padding on the left and right
      const resizedBuffer = await this.resizeWithAspectRatio(processedBuffer, targetWidth, targetHeight, regionIndex);
      // Save the resized image
      const resizedPath = path.join(this.preprocessedDir, `region-${regionIndex}-resized-${timestamp}.png`);
      await fse.writeFile(resizedPath, resizedBuffer);
      this.logger.debug(`Saved resized image: ${resizedPath}`);
      // Build the tensor data at the final size
      const inputData = await this.bufferToTensor(resizedBuffer, finalWidth, targetHeight);
      this.logger.debug(`Recognition input tensor ready - region ${regionIndex}`);
      // Create the tensor with the final dimensions
      return new Tensor('float32', inputData, [1, 3, targetHeight, finalWidth]);
    } catch (error) {
      this.logger.error(`Failed to prepare recognition input - region ${regionIndex}`, error);
      return new Tensor('float32', new Float32Array(3 * targetHeight * finalWidth).fill(0.5), [1, 3, targetHeight, finalWidth]);
    }
  }

  async applySmartPreprocessing(buffer, meanBrightness, stdDev, regionIndex = 0) {
    let processedBuffer = buffer;
    if (meanBrightness > 200 && stdDev < 30) {
      this.logger.debug(`Region ${regionIndex}: applying high-brightness enhancement`);
      processedBuffer = await sharp(buffer)
        .linear(1.5, -50)
        .normalize()
        .grayscale()
        .toBuffer();
    } else if (meanBrightness < 80) {
      this.logger.debug(`Region ${regionIndex}: applying low-brightness enhancement`);
      processedBuffer = await sharp(buffer)
        .linear(1.2, 30)
        .normalize()
        .grayscale()
        .toBuffer();
    } else if (stdDev < 20) {
      this.logger.debug(`Region ${regionIndex}: applying low-contrast enhancement`);
      processedBuffer = await sharp(buffer)
        .linear(1.3, -20)
        .normalize()
        .grayscale()
        .toBuffer();
    } else {
      this.logger.debug(`Region ${regionIndex}: applying standard normalize + grayscale`);
      processedBuffer = await sharp(buffer)
        .normalize()
        .grayscale()
        .toBuffer();
    }
    return processedBuffer;
  }

  async resizeWithAspectRatio(buffer, targetWidth, targetHeight, regionIndex = 0) {
    const metadata = await sharp(buffer).metadata();
    const originalAspectRatio = metadata.width / metadata.height;
    const targetAspectRatio = targetWidth / targetHeight;
    let resizeWidth, resizeHeight;
    if (originalAspectRatio > targetAspectRatio) {
      // Width-constrained: scale to the target width
      resizeWidth = targetWidth;
      resizeHeight = Math.round(targetWidth / originalAspectRatio);
    } else {
      // Height-constrained: scale to the target height
      resizeHeight = targetHeight;
      resizeWidth = Math.round(targetHeight * originalAspectRatio);
    }
    resizeWidth = Math.max(1, Math.min(resizeWidth, targetWidth));
    resizeHeight = Math.max(1, Math.min(resizeHeight, targetHeight));
    this.logger.debug(`Region ${regionIndex}: resizing ${metadata.width}x${metadata.height} -> ${resizeWidth}x${resizeHeight}`);
    // Compute offsets to center the resized content on the target canvas
    const offsetX = Math.floor((targetWidth - resizeWidth) / 2);
    const offsetY = Math.floor((targetHeight - resizeHeight) / 2);
    this.logger.debug(`Region ${regionIndex}: centering offsets X=${offsetX}, Y=${offsetY}`);
    // Resize first, then pad to center on a white canvas
    let resizedBuffer = await sharp(buffer)
      .resize(resizeWidth, resizeHeight, {
        fit: 'contain',
        background: { r: 255, g: 255, b: 255 }
      })
      .extend({
        top: offsetY,
        bottom: targetHeight - resizeHeight - offsetY,
        left: offsetX,
        right: targetWidth - resizeWidth - offsetX,
        background: { r: 255, g: 255, b: 255 }
      })
      .png()
      .toBuffer();
    // Add 10 px of blank padding on the left and right
    const finalWidth = targetWidth + 20; // 10 px on each side
    const finalHeight = targetHeight;
    resizedBuffer = await sharp(resizedBuffer)
      .extend({
        top: 0,
        bottom: 0,
        left: 10,
        right: 10,
        background: { r: 255, g: 255, b: 255 }
      })
      .png()
      .toBuffer();
    this.logger.debug(`Region ${regionIndex}: final size ${finalWidth}x${finalHeight} (10 px of blank padding on each side)`);
    return resizedBuffer;
  }
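
  // Worked example of the sizing math above (illustrative numbers): a 200x50 crop has
  // aspect ratio 4.0, which is below the target 320/48 ≈ 6.67, so it is height-constrained:
  // resizeHeight = 48, resizeWidth = round(48 * 4.0) = 192. Centering gives
  // offsetX = floor((320 - 192) / 2) = 64 and offsetY = 0, and the final extend step adds
  // 10 px on each side for a 340x48 canvas.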

  async bufferToTensor(buffer, width, height) {
    // Read the actual image dimensions (the width now includes the extra 20 px of padding)
    const metadata = await sharp(buffer).metadata();
    const actualWidth = metadata.width;
    const actualHeight = metadata.height;
    const imageData = await sharp(buffer)
      .ensureAlpha()
      .raw()
      .toBuffer({ resolveWithObject: true });
    // Allocate the tensor data using the actual dimensions
    const inputData = new Float32Array(3 * actualHeight * actualWidth);
    const data = imageData.data;
    for (let i = 0; i < data.length; i += 4) {
      const pixelIndex = Math.floor(i / 4);
      const y = Math.floor(pixelIndex / actualWidth);
      const x = pixelIndex % actualWidth;
      // Fill all three channels with the (already grayscale) red-channel value
      const grayValue = data[i] / 255.0;
      for (let c = 0; c < 3; c++) {
        const inputIndex = c * actualHeight * actualWidth + y * actualWidth + x;
        if (inputIndex < inputData.length) {
          inputData[inputIndex] = grayValue;
        }
      }
    }
    return inputData;
  }
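
  // Index-layout note (illustrative): the Float32Array is packed channel-first (CHW).
  // For a 340x48 input, the pixel at (x = 5, y = 2) in channel c = 1 lands at
  //   1 * 48 * 340 + 2 * 340 + 5 = 17005
  // Values are plain [0, 1] grayscale; no mean/std normalization is applied here.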

  postprocessRecognition(outputs) {
    this.logger.debug('Starting recognition post-processing');
    try {
      const outputNames = this.recSession.outputNames;
      const recognitionOutput = outputs[outputNames[0]];
      if (!recognitionOutput) {
        this.logger.debug('Recognition output is empty');
        return { text: '', confidence: 0 };
      }
      const data = recognitionOutput.data;
      const [batch, seqLen, vocabSize] = recognitionOutput.dims;
      this.logger.debug(`Sequence length: ${seqLen}, vocabulary size: ${vocabSize}, character set size: ${this.characterSet.length}`);
      if (this.characterSet.length === 0) {
        this.logger.error('Character set is empty');
        return { text: '', confidence: 0 };
      }
      // Check that the vocabulary size matches the character set size (+1 for the CTC blank)
      if (vocabSize !== this.characterSet.length + 1) {
        this.logger.warn(`Vocabulary size (${vocabSize}) does not match character set size (${this.characterSet.length}) + 1 blank; recognition quality may suffer`);
      }
      const { text, confidence } = this.ctcDecode(data, seqLen, vocabSize);
      this.logger.debug(`Decoded result: "${text}", confidence: ${confidence.toFixed(4)}`);
      return { text, confidence };
    } catch (error) {
      this.logger.error('Recognition post-processing failed', error);
      return { text: '', confidence: 0 };
    }
  }

  ctcDecode(data, seqLen, vocabSize) {
    let text = '';
    let lastCharIndex = -1;
    let confidenceSum = 0;
    let charCount = 0;
    // Dynamic threshold adjustment
    const baseThreshold = 0.03;
    let confidenceThreshold = baseThreshold;
    // First scan the whole sequence for its peak probability
    let maxSequenceProb = 0;
    for (let t = 0; t < seqLen; t++) {
      for (let i = 0; i < vocabSize; i++) {
        maxSequenceProb = Math.max(maxSequenceProb, data[t * vocabSize + i]);
      }
    }
    // If the overall confidence is low, lower the threshold
    if (maxSequenceProb < 0.5) {
      confidenceThreshold = baseThreshold * 0.5;
    }
    this.logger.debug(`Using decode threshold: ${confidenceThreshold.toFixed(4)}`);
    for (let t = 0; t < seqLen; t++) {
      let maxProb = -1;
      let maxIndex = -1;
      // Find the most probable class at this time step (greedy argmax)
      for (let i = 0; i < vocabSize; i++) {
        const prob = data[t * vocabSize + i];
        if (prob > maxProb) {
          maxProb = prob;
          maxIndex = i;
        }
      }
      // CTC decoding: index 0 is the blank; index i maps to characterSet[i - 1]
      if (maxIndex > 0 && maxProb > confidenceThreshold) {
        const charIndex = maxIndex - 1;
        if (charIndex < this.characterSet.length) {
          const char = this.characterSet[charIndex];
          // Repeated-character handling: collapse repeats unless confidence is very high
          const shouldAddChar = maxIndex !== lastCharIndex ||
            maxProb > 0.8 ||
            (maxIndex === lastCharIndex && charCount > 0 && text[text.length - 1] !== char);
          if (shouldAddChar && char && char.trim() !== '') {
            text += char;
            confidenceSum += maxProb;
            charCount++;
          }
          lastCharIndex = maxIndex;
        } else {
          this.logger.warn(`Character index ${charIndex} is outside the character set range (0-${this.characterSet.length - 1})`);
        }
      } else if (maxIndex === 0) {
        // Blank: reset so the same character can be emitted again after it
        lastCharIndex = -1;
      }
    }
    const avgConfidence = charCount > 0 ? confidenceSum / charCount : 0;
    // Basic text cleanup (no error-pattern correction)
    const cleanedText = this.basicTextCleaning(text);
    return {
      text: cleanedText,
      confidence: avgConfidence
    };
  }
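
  // Greedy-decode example (hypothetical values): with characterSet = ['2', '0', '5'] and
  // per-step argmax indices [3, 3, 0, 3] (all above the threshold, none above 0.8),
  // step 0 emits characterSet[2] = '5', step 1 is a collapsed repeat, step 2 is the blank
  // (index 0) and resets lastCharIndex, and step 3 emits '5' again, giving "55".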

  basicTextCleaning(text) {
    if (!text) return '';
    let cleaned = text;
    // 1. Collapse excessive repeated characters (keep reasonable doubles)
    cleaned = cleaned.replace(/([^0-9])\1{2,}/g, '$1$1');
    // 2. Fix punctuation
    cleaned = cleaned.replace(/∶/g, ':')
      .replace(/《/g, '(')
      .replace(/》/g, ')');
    // 3. Fix doubled percent signs after digits
    cleaned = cleaned.replace(/(\d+)%%/g, '$1%');
    return cleaned.trim();
  }
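
  // Cleanup examples (illustrative): '数量!!!!' -> '数量!!' (three or more repeats of a
  // non-digit collapse to two), '《备注》∶' -> '(备注):', '50%%' -> '50%'.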
}

export default TextRecognizer;
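
// Minimal usage sketch (illustrative only; the model/keys paths and the config object
// are assumptions, not defined by this module):
//
//   import { InferenceSession } from 'onnxruntime-node';
//   import fse from 'fs-extra';
//   import TextRecognizer from './textRecognizer.js';
//
//   const recognizer = new TextRecognizer();
//   const recSession = await InferenceSession.create('models/rec_model.onnx'); // hypothetical path
//   await recognizer.loadCharacterSet('models/keys.txt');                      // hypothetical path
//   recognizer.initialize(recSession, {});
//
//   const regionBuffer = await fse.readFile('temp/region-0.png');              // hypothetical crop
//   const { text, confidence } = await recognizer.recognizeText(regionBuffer, 0);
//   console.log(text, confidence);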