// Source: Electron-vue3-ts-offline/server/utils/onnxOcrManager.js
// Snapshot: 2025-11-13 18:09:31 +08:00 — 336 lines, 12 KiB, JavaScript
// server/utils/onnxOcrManager.js
import { InferenceSession } from 'onnxruntime-node';
import sharp from 'sharp';
import fse from 'fs-extra';
import * as path from 'path';
import { fileURLToPath } from 'url';
import DetectionProcessor from './detectionProcessor.js';
import RecognitionProcessor from './recognitionProcessor.js';
import ImagePreprocessor from './imagePreprocessor.js';
import TextPostProcessor from './textPostProcessor.js';
const __dirname = path.dirname(fileURLToPath(import.meta.url));
/**
 * Offline PP-OCRv3 pipeline driven by ONNX Runtime (CPU).
 *
 * Orchestrates four project-local stages: image preprocessing (with
 * padding), text detection, angle classification + recognition, and text
 * post-processing. Also renders an SVG overlay of the detected boxes onto
 * the original image for debugging/visualization.
 */
class OnnxOcrManager {
  constructor() {
    // ONNX inference sessions — created lazily in initialize().
    this.detSession = null;
    this.recSession = null;
    this.clsSession = null;
    this.isInitialized = false;
    // In-flight initialization promise; shared by concurrent initialize()
    // callers so the models are never loaded twice.
    this.initPromise = null;
    // Model layout under <cwd>/models/ocr (paths contain Chinese file names
    // shipped with the model bundle — do not "normalize" them).
    this.modelDir = path.join(process.cwd(), 'models', 'ocr');
    this.detModelPath = path.join(this.modelDir, 'Det', '中文_OCRv3.onnx');
    this.recModelPath = path.join(this.modelDir, 'Rec', '中文简体_OCRv3.onnx');
    this.clsModelPath = path.join(this.modelDir, 'Cls', '原始分类器模型.onnx');
    this.keysPath = path.join(this.modelDir, 'Keys', '中文简体_OCRv3.txt');
    // Pipeline stage helpers (project-local modules).
    this.detectionProcessor = new DetectionProcessor();
    this.recognitionProcessor = new RecognitionProcessor();
    this.imagePreprocessor = new ImagePreprocessor();
    this.textPostProcessor = new TextPostProcessor();
    // Lightweight console logger with emoji-tagged prefixes.
    this.logger = {
      info: (msg, ...args) => console.log(`🚀 [OCR管理器] ${msg}`, ...args),
      error: (msg, ...args) => console.error(`❌ [OCR管理器] ${msg}`, ...args),
      debug: (msg, ...args) => console.log(`🐛 [OCR管理器] ${msg}`, ...args)
    };
    // Make sure the visualization output directory exists up front.
    this.visualizationDir = path.join(process.cwd(), 'temp', 'visualization');
    fse.ensureDirSync(this.visualizationDir);
    // Tuned default parameters for the detection/recognition stages.
    this.defaultConfig = {
      language: 'ch',
      detLimitSideLen: 960,
      detThresh: 0.05,
      detBoxThresh: 0.08,
      detUnclipRatio: 1.8,
      maxTextLength: 100,
      recImageHeight: 48,
      clsThresh: 0.7,
      minTextHeight: 1,
      minTextWidth: 1,
      clusterDistance: 8,
      minClusterPoints: 1
    };
  }

  /**
   * Load the character set and the three ONNX models (detection,
   * recognition, classification), then wire them into the processors.
   *
   * Idempotent, and safe under concurrency: overlapping calls share one
   * in-flight promise instead of creating the sessions twice.
   *
   * @param {object} [config] - Overrides merged over this.defaultConfig.
   * @throws {Error} When a model file is missing or a session fails to load.
   */
  async initialize(config = {}) {
    if (this.isInitialized) {
      this.logger.info('OCR管理器已初始化');
      return;
    }
    if (this.initPromise) {
      // Another caller is already initializing — wait on the same promise.
      return this.initPromise;
    }
    this.initPromise = (async () => {
      try {
        this.logger.info('开始初始化OCR管理器...');
        await this.validateModelFiles();
        await this.recognitionProcessor.loadCharacterSet(this.keysPath);
        // Load all three models in parallel; CPU execution provider only.
        const [detSession, recSession, clsSession] = await Promise.all([
          InferenceSession.create(this.detModelPath, { executionProviders: ['cpu'] }),
          InferenceSession.create(this.recModelPath, { executionProviders: ['cpu'] }),
          InferenceSession.create(this.clsModelPath, { executionProviders: ['cpu'] })
        ]);
        this.detSession = detSession;
        this.recSession = recSession;
        this.clsSession = clsSession;
        const mergedConfig = { ...this.defaultConfig, ...config };
        this.detectionProcessor.initialize(this.detSession, mergedConfig);
        this.recognitionProcessor.initialize(this.recSession, this.clsSession, mergedConfig);
        this.isInitialized = true;
        this.logger.info('OCR管理器初始化完成');
      } catch (error) {
        this.logger.error('初始化失败', error);
        throw error;
      } finally {
        // Allow a retry after failure; no-op after success.
        this.initPromise = null;
      }
    })();
    return this.initPromise;
  }

  /**
   * Verify that all required model/character-set files exist on disk.
   * @throws {Error} Naming the first missing file path.
   */
  async validateModelFiles() {
    const requiredFiles = [
      { path: this.detModelPath, name: '检测模型' },
      { path: this.recModelPath, name: '识别模型' },
      { path: this.clsModelPath, name: '分类模型' },
      { path: this.keysPath, name: '字符集文件' }
    ];
    for (const { path: filePath, name } of requiredFiles) {
      const exists = await fse.pathExists(filePath);
      if (!exists) {
        throw new Error(`模型文件不存在: ${filePath}`);
      }
      this.logger.debug(`验证通过: ${name}`);
    }
    this.logger.info('所有模型文件验证通过');
  }

  /**
   * Run the full OCR pipeline on a single image file.
   *
   * Lazily initializes the manager if needed, then: preprocess → detect →
   * visualize boxes → classify+recognize → post-process into text blocks.
   *
   * @param {string} imagePath - Path to the image file to recognize.
   * @param {object} [config] - Per-call overrides (used only if this call
   *   triggers initialization; see initialize()).
   * @returns {Promise<object>} Result with textBlocks, rawText, confidence,
   *   timing, counts, and visualizationPath (null if the overlay failed).
   * @throws {Error} On invalid/missing path or any pipeline failure.
   */
  async recognizeImage(imagePath, config = {}) {
    if (!this.isInitialized) {
      await this.initialize(config);
    }
    if (!imagePath || typeof imagePath !== 'string') {
      throw new Error(`无效的图片路径: ${imagePath}`);
    }
    const imageExists = await fse.pathExists(imagePath);
    if (!imageExists) {
      throw new Error(`图片文件不存在: ${imagePath}`);
    }
    try {
      this.logger.info(`开始OCR识别: ${path.basename(imagePath)}`);
      const startTime = Date.now();
      const preprocessResult = await this.imagePreprocessor.preprocessWithPadding(imagePath, config);
      const { processedImage } = preprocessResult;
      const textBoxes = await this.detectionProcessor.detectText(processedImage);
      // Compute the output path ONCE and pass it down, so the file actually
      // written and the path reported in the result agree. (Previously both
      // sides called getVisualizationPath() independently and the embedded
      // Date.now() timestamps diverged.)
      const visualizationPath = await this.drawTextBoxesOnOriginalImage(
        imagePath,
        textBoxes,
        processedImage,
        this.getVisualizationPath(imagePath)
      );
      const recognitionResults = await this.recognitionProcessor.recognizeTextWithCls(processedImage, textBoxes);
      const processingTime = Date.now() - startTime;
      const textBlocks = this.textPostProcessor.buildTextBlocks(recognitionResults);
      const imageInfo = await this.imagePreprocessor.getImageInfo(imagePath);
      const rawText = textBlocks.map(block => block.content).join('\n');
      const overallConfidence = this.textPostProcessor.calculateOverallConfidence(recognitionResults);
      const result = {
        textBlocks,
        confidence: overallConfidence,
        processingTime,
        isOffline: true,
        imagePath,
        totalPages: 1,
        rawText,
        imageInfo,
        recognitionCount: recognitionResults.length,
        detectionCount: textBoxes.length,
        // null when the overlay could not be saved — see drawTextBoxesOnOriginalImage.
        visualizationPath
      };
      this.logger.info(`OCR识别完成:
- 处理时间: ${processingTime}ms
- 检测区域: ${textBoxes.length}
- 成功识别: ${recognitionResults.length}
- 总体置信度: ${overallConfidence.toFixed(4)}
- 最终文本: ${rawText.length}字符
- 可视化图像: ${result.visualizationPath}`);
      return result;
    } catch (error) {
      this.logger.error(`OCR识别失败: ${error.message}`);
      throw new Error(`OCR识别失败: ${error.message}`);
    }
  }

  /**
   * Render the detected text boxes as an SVG overlay composited onto the
   * original image, saved as a PNG. Best-effort: failures are logged, not
   * thrown, since visualization is a debugging aid.
   *
   * @param {string} originalImagePath - Source image path.
   * @param {Array<object>} textBoxes - Quadrilateral boxes in processed-image
   *   coordinates (x1..x4, y1..y4, confidence).
   * @param {object} processedImage - Preprocessing metadata used to map box
   *   coordinates back to the original image.
   * @param {string} [outputPath] - Where to write the PNG; defaults to a
   *   fresh timestamped path under the visualization directory.
   * @returns {Promise<string|null>} The saved path, or null on failure.
   */
  async drawTextBoxesOnOriginalImage(originalImagePath, textBoxes, processedImage, outputPath) {
    const visualizationPath = outputPath ?? this.getVisualizationPath(originalImagePath);
    try {
      this.logger.info('开始在原始图像上绘制文本框');
      const originalImage = sharp(originalImagePath);
      const metadata = await originalImage.metadata();
      // Build the SVG overlay, then composite it at the image origin.
      const svgOverlay = this.createTextBoxesSVG(textBoxes, processedImage, metadata);
      await originalImage
        .composite([{
          input: Buffer.from(svgOverlay),
          top: 0,
          left: 0
        }])
        .png()
        .toFile(visualizationPath);
      this.logger.info(`文本框可视化图像已保存: ${visualizationPath}`);
      return visualizationPath;
    } catch (error) {
      // Deliberately swallowed: OCR must still succeed without the overlay.
      this.logger.error('绘制文本框失败', error);
      return null;
    }
  }

  /**
   * Build an SVG document (sized to the original image) drawing each text
   * box as a polygon — green for confidence > 0.8, red otherwise — with an
   * index/confidence label above the box when there is room.
   *
   * @param {Array<object>} textBoxes - Boxes in processed-image coordinates.
   * @param {object} processedImage - Scale/padding metadata for coordinate mapping.
   * @param {{width: number, height: number}} originalMetadata - Original image size.
   * @returns {string} Complete SVG markup.
   */
  createTextBoxesSVG(textBoxes, processedImage, originalMetadata) {
    const { width, height } = originalMetadata;
    let svg = `<svg width="${width}" height="${height}" xmlns="http://www.w3.org/2000/svg">`;
    svg += `
<style>
.text-box {
fill: none;
stroke: #ff0000;
stroke-width: 2;
}
.text-box-high-conf {
fill: none;
stroke: #00ff00;
stroke-width: 2;
}
.text-label {
font-size: 12px;
fill: #ff0000;
font-family: Arial, sans-serif;
}
</style>
`;
    textBoxes.forEach((box, index) => {
      // Map processed-image coordinates back onto the original image.
      const originalBox = this.scaleBoxToOriginalImage(box, processedImage);
      const boxClass = box.confidence > 0.8 ? 'text-box-high-conf' : 'text-box';
      const points = [
        `${originalBox.x1},${originalBox.y1}`,
        `${originalBox.x2},${originalBox.y2}`,
        `${originalBox.x3},${originalBox.y3}`,
        `${originalBox.x4},${originalBox.y4}`
      ].join(' ');
      svg += `<polygon class="${boxClass}" points="${points}" />`;
      // Label anchor: top-left corner of the box's bounding rectangle,
      // nudged 5px upward; skipped near the top edge so it stays visible.
      const labelX = Math.min(originalBox.x1, originalBox.x2, originalBox.x3, originalBox.x4);
      const labelY = Math.min(originalBox.y1, originalBox.y2, originalBox.y3, originalBox.y4) - 5;
      if (labelY > 15) {
        svg += `<text class="text-label" x="${labelX}" y="${labelY}">${index + 1} (${box.confidence.toFixed(2)})</text>`;
      }
    });
    svg += '</svg>';
    return svg;
  }

  /**
   * Map one quadrilateral box from processed-image coordinates back to
   * original-image coordinates: apply the scale factors, subtract the
   * padding offsets, and clamp each corner into [0, size-1].
   *
   * @param {object} box - Corners x1..x4 / y1..y4 plus confidence.
   * @param {object} processedImage - { scaleX, scaleY, paddingX, paddingY,
   *   originalWidth, originalHeight } from the preprocessor.
   * @returns {object} Box with mapped corners and the original confidence.
   */
  scaleBoxToOriginalImage(box, processedImage) {
    const {
      scaleX, scaleY,
      paddingX, paddingY,
      originalWidth, originalHeight
    } = processedImage;
    const clamp = (value, max) => Math.max(0, Math.min(max, value));
    const mapped = { confidence: box.confidence };
    for (const i of [1, 2, 3, 4]) {
      // scale → un-pad → clamp, per corner.
      mapped[`x${i}`] = clamp(box[`x${i}`] * scaleX - paddingX, originalWidth - 1);
      mapped[`y${i}`] = clamp(box[`y${i}`] * scaleY - paddingY, originalHeight - 1);
    }
    return mapped;
  }

  /**
   * Build a unique timestamped PNG path in the visualization directory for
   * the given source image. Each call embeds a fresh Date.now() — compute
   * once and reuse if the same path must be referenced twice.
   *
   * @param {string} originalImagePath - Source image path.
   * @returns {string} e.g. <visualizationDir>/<name>-detection-<ts>.png
   */
  getVisualizationPath(originalImagePath) {
    const originalName = path.basename(originalImagePath, path.extname(originalImagePath));
    const timestamp = Date.now();
    return path.join(this.visualizationDir, `${originalName}-detection-${timestamp}.png`);
  }

  /**
   * Report engine status and configuration for diagnostics/health endpoints.
   * @returns {object} Initialization flag, model paths (relative to cwd),
   *   character-set size, and key threshold settings.
   */
  getStatus() {
    return {
      isInitialized: this.isInitialized,
      isOffline: true,
      engine: 'PP-OCRv3 (ONNX Runtime)',
      version: '2.0.0',
      models: {
        detection: path.relative(process.cwd(), this.detModelPath),
        recognition: path.relative(process.cwd(), this.recModelPath),
        classification: path.relative(process.cwd(), this.clsModelPath),
        characterSet: this.recognitionProcessor.getCharacterSetSize()
      },
      config: {
        detThresh: this.defaultConfig.detThresh,
        detBoxThresh: this.defaultConfig.detBoxThresh,
        clsThresh: this.defaultConfig.clsThresh,
        preprocessing: 'enhanced with smart padding'
      },
      backend: 'CPU'
    };
  }
}
// Module-level singleton: every importer shares the same OCR manager
// (and therefore the same lazily-loaded ONNX sessions).
export default new OnnxOcrManager();