// onnxOcrManager.js (~12 KB)
  1. // server/utils/onnxOcrManager.js
  2. import { InferenceSession } from 'onnxruntime-node';
  3. import sharp from 'sharp';
  4. import fse from 'fs-extra';
  5. import * as path from 'path';
  6. import { fileURLToPath } from 'url';
  7. import DetectionProcessor from './detectionProcessor.js';
  8. import RecognitionProcessor from './recognitionProcessor.js';
  9. import ImagePreprocessor from './imagePreprocessor.js';
  10. import TextPostProcessor from './textPostProcessor.js';
  11. const __dirname = path.dirname(fileURLToPath(import.meta.url));
  12. class OnnxOcrManager {
  13. constructor() {
  14. this.detSession = null;
  15. this.recSession = null;
  16. this.clsSession = null;
  17. this.isInitialized = false;
  18. this.modelDir = path.join(process.cwd(), 'models', 'ocr');
  19. this.detModelPath = path.join(this.modelDir, 'Det', '中文_OCRv3.onnx');
  20. this.recModelPath = path.join(this.modelDir, 'Rec', '中文简体_OCRv3.onnx');
  21. this.clsModelPath = path.join(this.modelDir, 'Cls', '原始分类器模型.onnx');
  22. this.keysPath = path.join(this.modelDir, 'Keys', '中文简体_OCRv3.txt');
  23. this.detectionProcessor = new DetectionProcessor();
  24. this.recognitionProcessor = new RecognitionProcessor();
  25. this.imagePreprocessor = new ImagePreprocessor();
  26. this.textPostProcessor = new TextPostProcessor();
  27. this.logger = {
  28. info: (msg, ...args) => console.log(`🚀 [OCR管理器] ${msg}`, ...args),
  29. error: (msg, ...args) => console.error(`❌ [OCR管理器] ${msg}`, ...args),
  30. debug: (msg, ...args) => console.log(`🐛 [OCR管理器] ${msg}`, ...args)
  31. };
  32. // 确保可视化目录存在
  33. this.visualizationDir = path.join(process.cwd(), 'temp', 'visualization');
  34. fse.ensureDirSync(this.visualizationDir);
  35. // 优化配置参数
  36. this.defaultConfig = {
  37. language: 'ch',
  38. detLimitSideLen: 960,
  39. detThresh: 0.05,
  40. detBoxThresh: 0.08,
  41. detUnclipRatio: 1.8,
  42. maxTextLength: 100,
  43. recImageHeight: 48,
  44. clsThresh: 0.7,
  45. minTextHeight: 1,
  46. minTextWidth: 1,
  47. clusterDistance: 8,
  48. minClusterPoints: 1
  49. };
  50. }
  51. async initialize(config = {}) {
  52. if (this.isInitialized) {
  53. this.logger.info('OCR管理器已初始化');
  54. return;
  55. }
  56. try {
  57. this.logger.info('开始初始化OCR管理器...');
  58. await this.validateModelFiles();
  59. await this.recognitionProcessor.loadCharacterSet(this.keysPath);
  60. const [detSession, recSession, clsSession] = await Promise.all([
  61. InferenceSession.create(this.detModelPath, { executionProviders: ['cpu'] }),
  62. InferenceSession.create(this.recModelPath, { executionProviders: ['cpu'] }),
  63. InferenceSession.create(this.clsModelPath, { executionProviders: ['cpu'] })
  64. ]);
  65. this.detSession = detSession;
  66. this.recSession = recSession;
  67. this.clsSession = clsSession;
  68. const mergedConfig = { ...this.defaultConfig, ...config };
  69. this.detectionProcessor.initialize(this.detSession, mergedConfig);
  70. this.recognitionProcessor.initialize(this.recSession, this.clsSession, mergedConfig);
  71. this.isInitialized = true;
  72. this.logger.info('OCR管理器初始化完成');
  73. } catch (error) {
  74. this.logger.error('初始化失败', error);
  75. throw error;
  76. }
  77. }
  78. async validateModelFiles() {
  79. const requiredFiles = [
  80. { path: this.detModelPath, name: '检测模型' },
  81. { path: this.recModelPath, name: '识别模型' },
  82. { path: this.clsModelPath, name: '分类模型' },
  83. { path: this.keysPath, name: '字符集文件' }
  84. ];
  85. for (const { path: filePath, name } of requiredFiles) {
  86. const exists = await fse.pathExists(filePath);
  87. if (!exists) {
  88. throw new Error(`模型文件不存在: ${filePath}`);
  89. }
  90. this.logger.debug(`验证通过: ${name}`);
  91. }
  92. this.logger.info('所有模型文件验证通过');
  93. }
  94. async recognizeImage(imagePath, config = {}) {
  95. if (!this.isInitialized) {
  96. await this.initialize(config);
  97. }
  98. if (!imagePath || typeof imagePath !== 'string') {
  99. throw new Error(`无效的图片路径: ${imagePath}`);
  100. }
  101. if (!fse.existsSync(imagePath)) {
  102. throw new Error(`图片文件不存在: ${imagePath}`);
  103. }
  104. try {
  105. this.logger.info(`开始OCR识别: ${path.basename(imagePath)}`);
  106. const startTime = Date.now();
  107. const preprocessResult = await this.imagePreprocessor.preprocessWithPadding(imagePath, config);
  108. const { processedImage } = preprocessResult;
  109. const textBoxes = await this.detectionProcessor.detectText(processedImage);
  110. // 在原始图像上绘制文本框
  111. await this.drawTextBoxesOnOriginalImage(imagePath, textBoxes, processedImage);
  112. const recognitionResults = await this.recognitionProcessor.recognizeTextWithCls(processedImage, textBoxes);
  113. const processingTime = Date.now() - startTime;
  114. const textBlocks = this.textPostProcessor.buildTextBlocks(recognitionResults);
  115. const imageInfo = await this.imagePreprocessor.getImageInfo(imagePath);
  116. const rawText = textBlocks.map(block => block.content).join('\n');
  117. const overallConfidence = this.textPostProcessor.calculateOverallConfidence(recognitionResults);
  118. const result = {
  119. textBlocks,
  120. confidence: overallConfidence,
  121. processingTime,
  122. isOffline: true,
  123. imagePath,
  124. totalPages: 1,
  125. rawText,
  126. imageInfo,
  127. recognitionCount: recognitionResults.length,
  128. detectionCount: textBoxes.length,
  129. visualizationPath: this.getVisualizationPath(imagePath)
  130. };
  131. this.logger.info(`OCR识别完成:
  132. - 处理时间: ${processingTime}ms
  133. - 检测区域: ${textBoxes.length}个
  134. - 成功识别: ${recognitionResults.length}个
  135. - 总体置信度: ${overallConfidence.toFixed(4)}
  136. - 最终文本: ${rawText.length}字符
  137. - 可视化图像: ${result.visualizationPath}`);
  138. return result;
  139. } catch (error) {
  140. this.logger.error(`OCR识别失败: ${error.message}`);
  141. throw new Error(`OCR识别失败: ${error.message}`);
  142. }
  143. }
  144. async drawTextBoxesOnOriginalImage(originalImagePath, textBoxes, processedImage) {
  145. try {
  146. this.logger.info('开始在原始图像上绘制文本框');
  147. // 读取原始图像
  148. const originalImage = sharp(originalImagePath);
  149. const metadata = await originalImage.metadata();
  150. // 创建SVG绘制指令
  151. const svgOverlay = this.createTextBoxesSVG(textBoxes, processedImage, metadata);
  152. // 将SVG叠加到原始图像上
  153. const visualizationPath = this.getVisualizationPath(originalImagePath);
  154. await originalImage
  155. .composite([{
  156. input: Buffer.from(svgOverlay),
  157. top: 0,
  158. left: 0
  159. }])
  160. .png()
  161. .toFile(visualizationPath);
  162. this.logger.info(`文本框可视化图像已保存: ${visualizationPath}`);
  163. } catch (error) {
  164. this.logger.error('绘制文本框失败', error);
  165. }
  166. }
  167. createTextBoxesSVG(textBoxes, processedImage, originalMetadata) {
  168. const { width, height } = originalMetadata;
  169. let svg = `<svg width="${width}" height="${height}" xmlns="http://www.w3.org/2000/svg">`;
  170. // 定义样式
  171. svg += `
  172. <style>
  173. .text-box {
  174. fill: none;
  175. stroke: #ff0000;
  176. stroke-width: 2;
  177. }
  178. .text-box-high-conf {
  179. fill: none;
  180. stroke: #00ff00;
  181. stroke-width: 2;
  182. }
  183. .text-label {
  184. font-size: 12px;
  185. fill: #ff0000;
  186. font-family: Arial, sans-serif;
  187. }
  188. </style>
  189. `;
  190. textBoxes.forEach((box, index) => {
  191. // 将处理后的图像坐标转换回原始图像坐标
  192. const originalBox = this.scaleBoxToOriginalImage(box, processedImage);
  193. // 根据置信度选择颜色
  194. const boxClass = box.confidence > 0.8 ? 'text-box-high-conf' : 'text-box';
  195. // 绘制文本框(多边形)
  196. const points = [
  197. `${originalBox.x1},${originalBox.y1}`,
  198. `${originalBox.x2},${originalBox.y2}`,
  199. `${originalBox.x3},${originalBox.y3}`,
  200. `${originalBox.x4},${originalBox.y4}`
  201. ].join(' ');
  202. svg += `<polygon class="${boxClass}" points="${points}" />`;
  203. // 在框上方添加索引和置信度标签
  204. const labelX = Math.min(originalBox.x1, originalBox.x2, originalBox.x3, originalBox.x4);
  205. const labelY = Math.min(originalBox.y1, originalBox.y2, originalBox.y3, originalBox.y4) - 5;
  206. if (labelY > 15) { // 确保标签在图像范围内
  207. svg += `<text class="text-label" x="${labelX}" y="${labelY}">${index + 1} (${box.confidence.toFixed(2)})</text>`;
  208. }
  209. });
  210. svg += '</svg>';
  211. return svg;
  212. }
  213. scaleBoxToOriginalImage(box, processedImage) {
  214. const {
  215. scaleX, scaleY,
  216. paddingX, paddingY,
  217. originalWidth, originalHeight
  218. } = processedImage;
  219. // 将处理后的图像坐标转换回填充后的图像坐标
  220. const paddedX1 = box.x1 * scaleX;
  221. const paddedY1 = box.y1 * scaleY;
  222. const paddedX2 = box.x2 * scaleX;
  223. const paddedY2 = box.y2 * scaleY;
  224. const paddedX3 = box.x3 * scaleX;
  225. const paddedY3 = box.y3 * scaleY;
  226. const paddedX4 = box.x4 * scaleX;
  227. const paddedY4 = box.y4 * scaleY;
  228. // 去除填充,得到原始图像坐标
  229. const originalX1 = paddedX1 - paddingX;
  230. const originalY1 = paddedY1 - paddingY;
  231. const originalX2 = paddedX2 - paddingX;
  232. const originalY2 = paddedY2 - paddingY;
  233. const originalX3 = paddedX3 - paddingX;
  234. const originalY3 = paddedY3 - paddingY;
  235. const originalX4 = paddedX4 - paddingX;
  236. const originalY4 = paddedY4 - paddingY;
  237. const clamp = (value, max) => Math.max(0, Math.min(max, value));
  238. return {
  239. x1: clamp(originalX1, originalWidth - 1),
  240. y1: clamp(originalY1, originalHeight - 1),
  241. x2: clamp(originalX2, originalWidth - 1),
  242. y2: clamp(originalY2, originalHeight - 1),
  243. x3: clamp(originalX3, originalWidth - 1),
  244. y3: clamp(originalY3, originalHeight - 1),
  245. x4: clamp(originalX4, originalWidth - 1),
  246. y4: clamp(originalY4, originalHeight - 1),
  247. confidence: box.confidence
  248. };
  249. }
  250. getVisualizationPath(originalImagePath) {
  251. const originalName = path.basename(originalImagePath, path.extname(originalImagePath));
  252. const timestamp = Date.now();
  253. return path.join(this.visualizationDir, `${originalName}-detection-${timestamp}.png`);
  254. }
  255. getStatus() {
  256. return {
  257. isInitialized: this.isInitialized,
  258. isOffline: true,
  259. engine: 'PP-OCRv3 (ONNX Runtime)',
  260. version: '2.0.0',
  261. models: {
  262. detection: path.relative(process.cwd(), this.detModelPath),
  263. recognition: path.relative(process.cwd(), this.recModelPath),
  264. classification: path.relative(process.cwd(), this.clsModelPath),
  265. characterSet: this.recognitionProcessor.getCharacterSetSize()
  266. },
  267. config: {
  268. detThresh: this.defaultConfig.detThresh,
  269. detBoxThresh: this.defaultConfig.detBoxThresh,
  270. clsThresh: this.defaultConfig.clsThresh,
  271. preprocessing: 'enhanced with smart padding'
  272. },
  273. backend: 'CPU'
  274. };
  275. }
  276. }
  277. const onnxOcrManager = new OnnxOcrManager();
  278. export default onnxOcrManager;