  1. // server/utils/textRecognizer.js
  2. import { Tensor } from 'onnxruntime-node';
  3. import sharp from 'sharp';
  4. import fse from 'fs-extra';
  5. import * as path from 'path';
  6. class TextRecognizer {
  7. constructor() {
  8. this.recSession = null;
  9. this.config = null;
  10. this.characterSet = [];
  11. this.debugDir = path.join(process.cwd(), 'temp', 'debug');
  12. this.preprocessedDir = path.join(process.cwd(), 'temp', 'preprocessed');
  13. this.logger = {
  14. info: (msg, ...args) => console.log(`🔤 [识别] ${msg}`, ...args),
  15. error: (msg, ...args) => console.error(`❌ [识别] ${msg}`, ...args),
  16. debug: (msg, ...args) => console.log(`🐛 [识别] ${msg}`, ...args),
  17. warn: (msg, ...args) => console.warn(`🐛 [识别] ${msg}`, ...args)
  18. };
  19. // 确保目录存在
  20. fse.ensureDirSync(this.debugDir);
  21. fse.ensureDirSync(this.preprocessedDir);
  22. }
  23. initialize(recSession, config) {
  24. this.recSession = recSession;
  25. this.config = config;
  26. this.logger.info('文本识别器初始化完成');
  27. }
  28. async loadCharacterSet(keysPath) {
  29. try {
  30. const keysContent = await fse.readFile(keysPath, 'utf8');
  31. this.characterSet = [];
  32. const lines = keysContent.split('\n');
  33. // 使用提供的字符集文件
  34. const uniqueChars = new Set();
  35. for (const line of lines) {
  36. const trimmed = line.trim();
  37. // 跳过空行和注释行
  38. if (trimmed && !trimmed.startsWith('#')) {
  39. // 将每行作为一个完整的字符处理
  40. uniqueChars.add(trimmed);
  41. }
  42. }
  43. this.characterSet = Array.from(uniqueChars);
  44. if (this.characterSet.length === 0) {
  45. throw new Error('字符集文件为空或格式不正确');
  46. }
  47. this.logger.info(`字符集加载完成: ${this.characterSet.length}个字符`);
  48. // 记录字符集统计信息
  49. const charTypes = {
  50. chinese: 0,
  51. english: 0,
  52. digit: 0,
  53. punctuation: 0,
  54. other: 0
  55. };
  56. this.characterSet.forEach(char => {
  57. if (/[\u4e00-\u9fff]/.test(char)) {
  58. charTypes.chinese++;
  59. } else if (/[a-zA-Z]/.test(char)) {
  60. charTypes.english++;
  61. } else if (/[0-9]/.test(char)) {
  62. charTypes.digit++;
  63. } else if (/[,。!?;:""()【】《》…—·]/.test(char)) {
  64. charTypes.punctuation++;
  65. } else {
  66. charTypes.other++;
  67. }
  68. });
  69. this.logger.debug(`字符集统计: 中文${charTypes.chinese}, 英文${charTypes.english}, 数字${charTypes.digit}, 标点${charTypes.punctuation}, 其他${charTypes.other}`);
  70. this.logger.debug(`前20个字符: ${this.characterSet.slice(0, 20).join('')}`);
  71. } catch (error) {
  72. this.logger.error('加载字符集失败', error.message);
  73. // 完全使用提供的字符集,失败时抛出错误
  74. throw new Error(`字符集加载失败: ${error.message}`);
  75. }
  76. }
  77. getCharacterSetSize() {
  78. return this.characterSet.length;
  79. }
  80. async recognizeText(textRegionBuffer, regionIndex = 0) {
  81. const startTime = Date.now();
  82. this.logger.info(`开始文本识别 - 区域 ${regionIndex}`);
  83. try {
  84. const inputTensor = await this.prepareRecognitionInput(textRegionBuffer, regionIndex);
  85. const outputs = await this.recSession.run({ [this.recSession.inputNames[0]]: inputTensor });
  86. const result = this.postprocessRecognition(outputs);
  87. const processingTime = Date.now() - startTime;
  88. this.logger.info(`识别完成 - 区域 ${regionIndex}: "${result.text}", 置信度: ${result.confidence.toFixed(4)}, 耗时: ${processingTime}ms`);
  89. return result;
  90. } catch (error) {
  91. this.logger.error(`文本识别失败 - 区域 ${regionIndex}`, error);
  92. return { text: '', confidence: 0 };
  93. }
  94. }
  95. async prepareRecognitionInput(textRegionBuffer, regionIndex = 0) {
  96. this.logger.debug(`准备识别输入 - 区域 ${regionIndex}`);
  97. const targetHeight = 48;
  98. const targetWidth = 320; // 原始目标宽度
  99. const finalWidth = targetWidth + 20; // 最终宽度(左右各加10像素)
  100. const timestamp = Date.now();
  101. try {
  102. const metadata = await sharp(textRegionBuffer).metadata();
  103. this.logger.debug(`原始区域 ${regionIndex}: ${metadata.width}x${metadata.height}`);
  104. // 保存原始裁剪区域图像
  105. const originalPath = path.join(this.preprocessedDir, `region-${regionIndex}-original-${timestamp}.png`);
  106. await fse.writeFile(originalPath, textRegionBuffer);
  107. this.logger.debug(`保存原始区域图像: ${originalPath}`);
  108. // 图像分析
  109. const stats = await sharp(textRegionBuffer).grayscale().stats();
  110. const meanBrightness = stats.channels[0].mean;
  111. const stdDev = stats.channels[0].stdev;
  112. this.logger.debug(`图像统计 - 区域 ${regionIndex}: 亮度=${meanBrightness.toFixed(1)}, 对比度=${stdDev.toFixed(1)}`);
  113. // 智能预处理
  114. let processedBuffer = await this.applySmartPreprocessing(textRegionBuffer, meanBrightness, stdDev, regionIndex);
  115. // 保存预处理后的图像(灰度+对比度调整后)
  116. const processedPath = path.join(this.preprocessedDir, `region-${regionIndex}-processed-${timestamp}.png`);
  117. await fse.writeFile(processedPath, processedBuffer);
  118. this.logger.debug(`保存预处理图像: ${processedPath}`);
  119. // 保持宽高比的resize,并在左右添加10像素空白
  120. const resizedBuffer = await this.resizeWithAspectRatio(processedBuffer, targetWidth, targetHeight, regionIndex);
  121. // 保存调整大小后的图像
  122. const resizedPath = path.join(this.preprocessedDir, `region-${regionIndex}-resized-${timestamp}.png`);
  123. await fse.writeFile(resizedPath, resizedBuffer);
  124. this.logger.debug(`保存调整大小图像: ${resizedPath}`);
  125. // 使用最终尺寸创建张量
  126. const inputData = await this.bufferToTensor(resizedBuffer, finalWidth, targetHeight);
  127. this.logger.debug(`识别输入张量准备完成 - 区域 ${regionIndex}`);
  128. // 创建张量时使用最终尺寸
  129. return new Tensor('float32', inputData, [1, 3, targetHeight, finalWidth]);
  130. } catch (error) {
  131. this.logger.error(`准备识别输入失败 - 区域 ${regionIndex}`, error);
  132. return new Tensor('float32', new Float32Array(3 * targetHeight * finalWidth).fill(0.5), [1, 3, targetHeight, finalWidth]);
  133. }
  134. }
  135. // server/utils/textRecognizer.js
  136. // 增强的图像预处理
  137. async applySmartPreprocessing(buffer, meanBrightness, stdDev, regionIndex = 0) {
  138. let processedBuffer = buffer;
  139. try {
  140. // 更精细的图像分析
  141. const stats = await sharp(buffer)
  142. .grayscale()
  143. .stats();
  144. const median = stats.channels[0].median;
  145. const max = stats.channels[0].max;
  146. const min = stats.channels[0].min;
  147. this.logger.debug(`区域 ${regionIndex}: 详细统计 - 中值=${median}, 范围=${min}-${max}, 均值=${meanBrightness.toFixed(1)}, 标准差=${stdDev.toFixed(1)}`);
  148. // 更智能的预处理策略
  149. if (meanBrightness > 220 && stdDev < 25) {
  150. // 高亮度低对比度图像
  151. this.logger.debug(`区域 ${regionIndex}: 应用高亮度低对比度增强`);
  152. processedBuffer = await sharp(buffer)
  153. .linear(1.8, -80) // 更强的对比度增强
  154. .normalize({ lower: 5, upper: 95 }) // 更激进的归一化
  155. .grayscale()
  156. .toBuffer();
  157. } else if (meanBrightness < 70) {
  158. // 低亮度图像
  159. this.logger.debug(`区域 ${regionIndex}: 应用低亮度增强`);
  160. processedBuffer = await sharp(buffer)
  161. .linear(1.5, 50) // 更强的亮度提升
  162. .normalize()
  163. .grayscale()
  164. .toBuffer();
  165. } else if (stdDev < 15) {
  166. // 极低对比度
  167. this.logger.debug(`区域 ${regionIndex}: 应用极低对比度增强`);
  168. processedBuffer = await sharp(buffer)
  169. .linear(2.0, -30) // 非常强的对比度增强
  170. .normalize({ lower: 1, upper: 99 })
  171. .grayscale()
  172. .toBuffer();
  173. } else if (stdDev > 80) {
  174. // 高对比度图像,可能过度增强
  175. this.logger.debug(`区域 ${regionIndex}: 应用高对比度抑制`);
  176. processedBuffer = await sharp(buffer)
  177. .linear(0.8, 20) // 降低对比度
  178. .normalize()
  179. .grayscale()
  180. .toBuffer();
  181. } else {
  182. // 标准处理
  183. this.logger.debug(`区域 ${regionIndex}: 应用标准增强`);
  184. processedBuffer = await sharp(buffer)
  185. .linear(1.3, -15) // 适度的对比度增强
  186. .normalize({ lower: 10, upper: 90 })
  187. .grayscale()
  188. .toBuffer();
  189. }
  190. // 应用锐化滤波增强文字边缘
  191. processedBuffer = await sharp(processedBuffer)
  192. .sharpen({
  193. sigma: 1.2,
  194. m1: 1.5,
  195. m2: 0.7
  196. })
  197. .toBuffer();
  198. } catch (error) {
  199. this.logger.error(`区域 ${regionIndex}: 预处理失败`, error);
  200. // 回退到基本处理
  201. processedBuffer = await sharp(buffer)
  202. .normalize()
  203. .grayscale()
  204. .toBuffer();
  205. }
  206. return processedBuffer;
  207. }
  208. async resizeWithAspectRatio(buffer, targetWidth, targetHeight, regionIndex = 0) {
  209. const metadata = await sharp(buffer).metadata();
  210. const originalAspectRatio = metadata.width / metadata.height;
  211. const targetAspectRatio = targetWidth / targetHeight;
  212. let resizeWidth, resizeHeight;
  213. if (originalAspectRatio > targetAspectRatio) {
  214. // 宽度限制,按宽度缩放
  215. resizeWidth = targetWidth;
  216. resizeHeight = Math.round(targetWidth / originalAspectRatio);
  217. } else {
  218. // 高度限制,按高度缩放
  219. resizeHeight = targetHeight;
  220. resizeWidth = Math.round(targetHeight * originalAspectRatio);
  221. }
  222. resizeWidth = Math.max(1, Math.min(resizeWidth, targetWidth));
  223. resizeHeight = Math.max(1, Math.min(resizeHeight, targetHeight));
  224. this.logger.debug(`区域 ${regionIndex}: 调整尺寸 ${metadata.width}x${metadata.height} -> ${resizeWidth}x${resizeHeight}`);
  225. // 计算居中的偏移量
  226. const offsetX = Math.floor((targetWidth - resizeWidth) / 2);
  227. const offsetY = Math.floor((targetHeight - resizeHeight) / 2);
  228. this.logger.debug(`区域 ${regionIndex}: 居中偏移 X=${offsetX}, Y=${offsetY}`);
  229. // 先调整大小并居中
  230. let resizedBuffer = await sharp(buffer)
  231. .resize(resizeWidth, resizeHeight, {
  232. fit: 'contain',
  233. background: { r: 255, g: 255, b: 255 }
  234. })
  235. .extend({
  236. top: offsetY,
  237. bottom: targetHeight - resizeHeight - offsetY,
  238. left: offsetX,
  239. right: targetWidth - resizeWidth - offsetX,
  240. background: { r: 255, g: 255, b: 255 }
  241. })
  242. .png()
  243. .toBuffer();
  244. // 在左右各添加10像素空白
  245. const finalWidth = targetWidth + 20; // 左右各加10像素
  246. const finalHeight = targetHeight;
  247. resizedBuffer = await sharp(resizedBuffer)
  248. .extend({
  249. top: 0,
  250. bottom: 0,
  251. left: 10,
  252. right: 10,
  253. background: { r: 255, g: 255, b: 255 }
  254. })
  255. .png()
  256. .toBuffer();
  257. this.logger.debug(`区域 ${regionIndex}: 最终尺寸 ${finalWidth}x${finalHeight} (左右各加10像素空白)`);
  258. return resizedBuffer;
  259. }
  260. async bufferToTensor(buffer, width, height) {
  261. // 获取实际图像尺寸(因为现在宽度增加了20像素)
  262. const metadata = await sharp(buffer).metadata();
  263. const actualWidth = metadata.width;
  264. const actualHeight = metadata.height;
  265. const imageData = await sharp(buffer)
  266. .ensureAlpha()
  267. .raw()
  268. .toBuffer({ resolveWithObject: true });
  269. // 使用实际尺寸创建张量
  270. const inputData = new Float32Array(3 * actualHeight * actualWidth);
  271. const data = imageData.data;
  272. for (let i = 0; i < data.length; i += 4) {
  273. const pixelIndex = Math.floor(i / 4);
  274. const y = Math.floor(pixelIndex / actualWidth);
  275. const x = pixelIndex % actualWidth;
  276. // 使用灰度值填充三个通道
  277. const grayValue = data[i] / 255.0;
  278. for (let c = 0; c < 3; c++) {
  279. const inputIndex = c * actualHeight * actualWidth + y * actualWidth + x;
  280. if (inputIndex < inputData.length) {
  281. inputData[inputIndex] = grayValue;
  282. }
  283. }
  284. }
  285. return inputData;
  286. }
  287. postprocessRecognition(outputs) {
  288. this.logger.debug('开始识别后处理');
  289. try {
  290. const outputNames = this.recSession.outputNames;
  291. const recognitionOutput = outputs[outputNames[0]];
  292. if (!recognitionOutput) {
  293. this.logger.debug('识别输出为空');
  294. return { text: '', confidence: 0 };
  295. }
  296. const data = recognitionOutput.data;
  297. const [batch, seqLen, vocabSize] = recognitionOutput.dims;
  298. this.logger.debug(`序列长度: ${seqLen}, 词汇表大小: ${vocabSize}, 字符集大小: ${this.characterSet.length}`);
  299. if (this.characterSet.length === 0) {
  300. this.logger.error('字符集为空');
  301. return { text: '', confidence: 0 };
  302. }
  303. // 验证词汇表大小与字符集大小的匹配
  304. if (vocabSize !== this.characterSet.length + 1) {
  305. this.logger.warn(`词汇表大小(${vocabSize})与字符集大小(${this.characterSet.length})不匹配,可能影响识别效果`);
  306. }
  307. const { text, confidence } = this.ctcDecode(data, seqLen, vocabSize);
  308. this.logger.debug(`解码结果: "${text}", 置信度: ${confidence.toFixed(4)}`);
  309. return { text, confidence };
  310. } catch (error) {
  311. this.logger.error('识别后处理失败', error);
  312. return { text: '', confidence: 0 };
  313. }
  314. }
  315. ctcDecode(data, seqLen, vocabSize) {
  316. let text = '';
  317. let lastCharIndex = -1;
  318. let confidenceSum = 0;
  319. let charCount = 0;
  320. // 动态阈值调整
  321. const baseThreshold = 0.03;
  322. let confidenceThreshold = baseThreshold;
  323. // 分析序列置信度分布
  324. let maxSequenceProb = 0;
  325. let minSequenceProb = 1;
  326. let sumProb = 0;
  327. let probCount = 0;
  328. for (let t = 0; t < seqLen; t++) {
  329. for (let i = 0; i < vocabSize; i++) {
  330. const prob = data[t * vocabSize + i];
  331. if (prob > 0.01) { // 只统计有意义的概率
  332. maxSequenceProb = Math.max(maxSequenceProb, prob);
  333. minSequenceProb = Math.min(minSequenceProb, prob);
  334. sumProb += prob;
  335. probCount++;
  336. }
  337. }
  338. }
  339. const avgProb = probCount > 0 ? sumProb / probCount : 0;
  340. // 根据序列特性动态调整阈值
  341. if (avgProb < 0.3) {
  342. confidenceThreshold = baseThreshold * 0.5;
  343. } else if (avgProb > 0.7) {
  344. confidenceThreshold = baseThreshold * 1.5;
  345. }
  346. this.logger.debug(`序列统计: 平均概率=${avgProb.toFixed(4)}, 使用解码阈值: ${confidenceThreshold.toFixed(4)}`);
  347. // 改进的beam search算法
  348. const beamWidth = 5;
  349. let beams = [{ text: '', confidence: 1.0, lastChar: -1 }];
  350. for (let t = 0; t < seqLen; t++) {
  351. const newBeams = [];
  352. // 获取当前时间步的top-k字符
  353. const topK = [];
  354. for (let i = 0; i < vocabSize; i++) {
  355. const prob = data[t * vocabSize + i];
  356. if (prob > confidenceThreshold) {
  357. topK.push({ index: i, prob });
  358. }
  359. }
  360. // 按概率排序
  361. topK.sort((a, b) => b.prob - a.prob);
  362. const candidates = topK.slice(0, beamWidth);
  363. // 为每个beam扩展候选字符
  364. for (const beam of beams) {
  365. for (const candidate of candidates) {
  366. const charIndex = candidate.index;
  367. if (charIndex === 0) {
  368. // 空白字符
  369. newBeams.push({
  370. text: beam.text,
  371. confidence: beam.confidence,
  372. lastChar: -1
  373. });
  374. } else {
  375. const actualCharIndex = charIndex - 1;
  376. if (actualCharIndex < this.characterSet.length) {
  377. const char = this.characterSet[actualCharIndex];
  378. let newText = beam.text;
  379. // 处理重复字符
  380. if (charIndex !== beam.lastChar) {
  381. newText += char;
  382. }
  383. newBeams.push({
  384. text: newText,
  385. confidence: beam.confidence * candidate.prob,
  386. lastChar: charIndex
  387. });
  388. }
  389. }
  390. }
  391. }
  392. // 选择top beamWidth个beam
  393. newBeams.sort((a, b) => b.confidence - a.confidence);
  394. beams = newBeams.slice(0, beamWidth);
  395. }
  396. // 选择最佳beam
  397. if (beams.length > 0) {
  398. const bestBeam = beams[0];
  399. text = bestBeam.text;
  400. // 计算平均置信度(几何平均)
  401. const textLength = text.length;
  402. if (textLength > 0) {
  403. confidenceSum = Math.pow(bestBeam.confidence, 1 / textLength);
  404. charCount = textLength;
  405. }
  406. }
  407. const avgConfidence = charCount > 0 ? confidenceSum : 0;
  408. return {
  409. text: text,
  410. confidence: avgConfidence
  411. };
  412. }
  413. }
  414. export default TextRecognizer;