// server/utils/onnxOcrManager.js
import { InferenceSession } from 'onnxruntime-node';
import sharp from 'sharp';
import fse from 'fs-extra';
import * as path from 'path';
import { fileURLToPath } from 'url';
import DetectionProcessor from './detectionProcessor.js';
import RecognitionProcessor from './recognitionProcessor.js';
import ImagePreprocessor from './imagePreprocessor.js';
import TextPostProcessor from './textPostProcessor.js';

const __dirname = path.dirname(fileURLToPath(import.meta.url));

class OnnxOcrManager {
  constructor() {
    this.detSession = null;
    this.recSession = null;
    this.clsSession = null;
    this.isInitialized = false;

    this.modelDir = path.join(process.cwd(), 'models', 'ocr');
    this.detModelPath = path.join(this.modelDir, 'Det', '中文_OCRv3.onnx');
    this.recModelPath = path.join(this.modelDir, 'Rec', '中文简体_OCRv3.onnx');
    this.clsModelPath = path.join(this.modelDir, 'Cls', '原始分类器模型.onnx');
    this.keysPath = path.join(this.modelDir, 'Keys', '中文简体_OCRv3.txt');

    this.detectionProcessor = new DetectionProcessor();
    this.recognitionProcessor = new RecognitionProcessor();
    this.imagePreprocessor = new ImagePreprocessor();
    this.textPostProcessor = new TextPostProcessor();

    this.logger = {
      info: (msg, ...args) => console.log(`🚀 [OCR Manager] ${msg}`, ...args),
      error: (msg, ...args) => console.error(`❌ [OCR Manager] ${msg}`, ...args),
      debug: (msg, ...args) => console.log(`🐛 [OCR Manager] ${msg}`, ...args)
    };

    // Make sure the visualization output directory exists
    this.visualizationDir = path.join(process.cwd(), 'temp', 'visualization');
    fse.ensureDirSync(this.visualizationDir);

    // Tuned default configuration
    this.defaultConfig = {
      language: 'ch',
      detLimitSideLen: 960,
      detThresh: 0.05,
      detBoxThresh: 0.08,
      detUnclipRatio: 1.8,
      maxTextLength: 100,
      recImageHeight: 48,
      clsThresh: 0.7,
      minTextHeight: 1,
      minTextWidth: 1,
      clusterDistance: 8,
      minClusterPoints: 1
    };
  }

  async initialize(config = {}) {
    if (this.isInitialized) {
      this.logger.info('OCR manager is already initialized');
      return;
    }

    try {
      this.logger.info('Initializing OCR manager...');

      await this.validateModelFiles();
      await this.recognitionProcessor.loadCharacterSet(this.keysPath);

      const [detSession, recSession, clsSession] = await Promise.all([
        InferenceSession.create(this.detModelPath, { executionProviders: ['cpu'] }),
        InferenceSession.create(this.recModelPath, { executionProviders: ['cpu'] }),
        InferenceSession.create(this.clsModelPath, { executionProviders: ['cpu'] })
      ]);

      this.detSession = detSession;
      this.recSession = recSession;
      this.clsSession = clsSession;

      const mergedConfig = { ...this.defaultConfig, ...config };
      this.detectionProcessor.initialize(this.detSession, mergedConfig);
      this.recognitionProcessor.initialize(this.recSession, this.clsSession, mergedConfig);

      this.isInitialized = true;
      this.logger.info('OCR manager initialized');
    } catch (error) {
      this.logger.error('Initialization failed', error);
      throw error;
    }
  }

  async validateModelFiles() {
    const requiredFiles = [
      { path: this.detModelPath, name: 'detection model' },
      { path: this.recModelPath, name: 'recognition model' },
      { path: this.clsModelPath, name: 'classification model' },
      { path: this.keysPath, name: 'character set file' }
    ];

    for (const { path: filePath, name } of requiredFiles) {
      const exists = await fse.pathExists(filePath);
      if (!exists) {
        throw new Error(`Model file not found: ${filePath}`);
      }
      this.logger.debug(`Validated: ${name}`);
    }

    this.logger.info('All model files validated');
  }
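  /**
   * End-to-end recognition pipeline implemented below:
   *   1. Preprocess the image with smart padding (ImagePreprocessor).
   *   2. Run text detection on the processed image (DetectionProcessor).
   *   3. Draw the detected boxes onto the original image for debugging.
   *   4. Run direction classification + recognition per box (RecognitionProcessor).
   *   5. Merge per-box results into text blocks and compute overall confidence
   *      (TextPostProcessor).
   */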
  async recognizeImage(imagePath, config = {}) {
    if (!this.isInitialized) {
      await this.initialize(config);
    }

    if (!imagePath || typeof imagePath !== 'string') {
      throw new Error(`Invalid image path: ${imagePath}`);
    }
    if (!fse.existsSync(imagePath)) {
      throw new Error(`Image file not found: ${imagePath}`);
    }

    try {
      this.logger.info(`Starting OCR recognition: ${path.basename(imagePath)}`);
      const startTime = Date.now();

      const preprocessResult = await this.imagePreprocessor.preprocessWithPadding(imagePath, config);
      const { processedImage } = preprocessResult;

      const textBoxes = await this.detectionProcessor.detectText(processedImage);

      // Draw the detected boxes on the original image and keep the output path
      const visualizationPath = await this.drawTextBoxesOnOriginalImage(imagePath, textBoxes, processedImage);

      const recognitionResults = await this.recognitionProcessor.recognizeTextWithCls(processedImage, textBoxes);
      const processingTime = Date.now() - startTime;

      const textBlocks = this.textPostProcessor.buildTextBlocks(recognitionResults);
      const imageInfo = await this.imagePreprocessor.getImageInfo(imagePath);
      const rawText = textBlocks.map(block => block.content).join('\n');
      const overallConfidence = this.textPostProcessor.calculateOverallConfidence(recognitionResults);

      const result = {
        textBlocks,
        confidence: overallConfidence,
        processingTime,
        isOffline: true,
        imagePath,
        totalPages: 1,
        rawText,
        imageInfo,
        recognitionCount: recognitionResults.length,
        detectionCount: textBoxes.length,
        visualizationPath
      };

      this.logger.info(`OCR recognition finished:
  - processing time: ${processingTime}ms
  - detected regions: ${textBoxes.length}
  - recognized regions: ${recognitionResults.length}
  - overall confidence: ${overallConfidence.toFixed(4)}
  - final text: ${rawText.length} characters
  - visualization image: ${result.visualizationPath}`);

      return result;
    } catch (error) {
      this.logger.error(`OCR recognition failed: ${error.message}`);
      throw new Error(`OCR recognition failed: ${error.message}`);
    }
  }

  async drawTextBoxesOnOriginalImage(originalImagePath, textBoxes, processedImage) {
    try {
      this.logger.info('Drawing detected text boxes on the original image');

      // Load the original image and read its dimensions
      const originalImage = sharp(originalImagePath);
      const metadata = await originalImage.metadata();

      // Build an SVG overlay with one polygon (and label) per detected box
      const svgOverlay = this.createTextBoxesSVG(textBoxes, processedImage, metadata);

      // Composite the SVG overlay onto the original image and write it out
      const visualizationPath = this.getVisualizationPath(originalImagePath);
      await originalImage
        .composite([{
          input: Buffer.from(svgOverlay),
          top: 0,
          left: 0
        }])
        .png()
        .toFile(visualizationPath);

      this.logger.info(`Text-box visualization saved: ${visualizationPath}`);
      return visualizationPath;
    } catch (error) {
      this.logger.error('Failed to draw text boxes', error);
      return null;
    }
  }
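  /**
   * Coordinate convention for the overlay: detection boxes are produced in the
   * processed (resized + padded) image space. scaleBoxToOriginalImage() maps them
   * back by multiplying with scaleX/scaleY, subtracting paddingX/paddingY, and
   * clamping to the original image bounds, so the SVG can be drawn directly on
   * top of the original image.
   */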
  createTextBoxesSVG(textBoxes, processedImage, originalMetadata) {
    const { width, height } = originalMetadata;

    let svg = `<svg width="${width}" height="${height}" xmlns="http://www.w3.org/2000/svg">`;

    // Define styles: normal boxes vs. high-confidence boxes (stroke colors are illustrative)
    svg += `<style>
      .text-box { fill: none; stroke: #ff3b30; stroke-width: 2; }
      .text-box-high-conf { fill: none; stroke: #34c759; stroke-width: 2; }
      .text-label { fill: #ff3b30; font-size: 14px; font-family: sans-serif; }
    </style>`;

    textBoxes.forEach((box, index) => {
      // Map the processed-image coordinates back onto the original image
      const originalBox = this.scaleBoxToOriginalImage(box, processedImage);

      // Pick the style class based on the box confidence
      const boxClass = box.confidence > 0.8 ? 'text-box-high-conf' : 'text-box';

      // Draw the text box as a polygon
      const points = [
        `${originalBox.x1},${originalBox.y1}`,
        `${originalBox.x2},${originalBox.y2}`,
        `${originalBox.x3},${originalBox.y3}`,
        `${originalBox.x4},${originalBox.y4}`
      ].join(' ');
      svg += `<polygon points="${points}" class="${boxClass}" />`;

      // Add an index + confidence label above the box
      const labelX = Math.min(originalBox.x1, originalBox.x2, originalBox.x3, originalBox.x4);
      const labelY = Math.min(originalBox.y1, originalBox.y2, originalBox.y3, originalBox.y4) - 5;
      if (labelY > 15) { // keep the label inside the image
        svg += `<text x="${labelX}" y="${labelY}" class="text-label">${index + 1} (${box.confidence.toFixed(2)})</text>`;
      }
    });

    svg += '</svg>';
    return svg;
  }

  scaleBoxToOriginalImage(box, processedImage) {
    const { scaleX, scaleY, paddingX, paddingY, originalWidth, originalHeight } = processedImage;

    // Map processed-image coordinates back to padded-image coordinates
    const paddedX1 = box.x1 * scaleX;
    const paddedY1 = box.y1 * scaleY;
    const paddedX2 = box.x2 * scaleX;
    const paddedY2 = box.y2 * scaleY;
    const paddedX3 = box.x3 * scaleX;
    const paddedY3 = box.y3 * scaleY;
    const paddedX4 = box.x4 * scaleX;
    const paddedY4 = box.y4 * scaleY;

    // Remove the padding to get original-image coordinates
    const originalX1 = paddedX1 - paddingX;
    const originalY1 = paddedY1 - paddingY;
    const originalX2 = paddedX2 - paddingX;
    const originalY2 = paddedY2 - paddingY;
    const originalX3 = paddedX3 - paddingX;
    const originalY3 = paddedY3 - paddingY;
    const originalX4 = paddedX4 - paddingX;
    const originalY4 = paddedY4 - paddingY;

    const clamp = (value, max) => Math.max(0, Math.min(max, value));

    return {
      x1: clamp(originalX1, originalWidth - 1),
      y1: clamp(originalY1, originalHeight - 1),
      x2: clamp(originalX2, originalWidth - 1),
      y2: clamp(originalY2, originalHeight - 1),
      x3: clamp(originalX3, originalWidth - 1),
      y3: clamp(originalY3, originalHeight - 1),
      x4: clamp(originalX4, originalWidth - 1),
      y4: clamp(originalY4, originalHeight - 1),
      confidence: box.confidence
    };
  }

  getVisualizationPath(originalImagePath) {
    const originalName = path.basename(originalImagePath, path.extname(originalImagePath));
    const timestamp = Date.now();
    return path.join(this.visualizationDir, `${originalName}-detection-${timestamp}.png`);
  }

  getStatus() {
    return {
      isInitialized: this.isInitialized,
      isOffline: true,
      engine: 'PP-OCRv3 (ONNX Runtime)',
      version: '2.0.0',
      models: {
        detection: path.relative(process.cwd(), this.detModelPath),
        recognition: path.relative(process.cwd(), this.recModelPath),
        classification: path.relative(process.cwd(), this.clsModelPath),
        characterSet: this.recognitionProcessor.getCharacterSetSize()
      },
      config: {
        detThresh: this.defaultConfig.detThresh,
        detBoxThresh: this.defaultConfig.detBoxThresh,
        clsThresh: this.defaultConfig.clsThresh,
        preprocessing: 'enhanced with smart padding'
      },
      backend: 'CPU'
    };
  }
}

const onnxOcrManager = new OnnxOcrManager();
export default onnxOcrManager;
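
// Usage sketch (illustrative only; the import path, image path, and override value
// are assumptions, not part of this module):
//
//   import onnxOcrManager from './utils/onnxOcrManager.js';
//
//   const result = await onnxOcrManager.recognizeImage('./uploads/sample.png', {
//     detThresh: 0.1 // per-call overrides are only merged on the first initialization
//   });
//   console.log(result.rawText, result.confidence, result.visualizationPath);
//   console.log(onnxOcrManager.getStatus());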