// server/utils/textDirectionClassifier.js

/** Height, in pixels, of the image handed to the classification model. */
const CLS_INPUT_HEIGHT = 48;

/** Width, in pixels, of the image handed to the classification model. */
const CLS_INPUT_WIDTH = 192;

/**
 * Builds the result returned whenever classification cannot be performed.
 * A fresh object is created per call so callers may mutate it safely.
 *
 * @returns {{ clsResult: number, clsConfidence: number }} "Not rotated" with confidence 1.0.
 */
function fallbackResult() {
  return { clsResult: 0, clsConfidence: 1.0 };
}

/**
 * Converts interleaved 8-bit pixel data (pixel after pixel, channel after
 * channel) into three consecutive colour planes of floats scaled to [0, 1].
 *
 * Images with fewer than three channels (grey, grey + alpha) have their single
 * colour channel copied into all three planes; any channel beyond the third
 * (alpha) is ignored.
 *
 * @param {ArrayLike<number>} pixels - Interleaved samples in the range 0-255.
 * @param {number} channels - Samples per pixel in `pixels` (>= 1).
 * @param {number} height - Image height in pixels.
 * @param {number} width - Image width in pixels.
 * @returns {Float32Array} `3 * height * width` values: plane 0, then plane 1, then plane 2.
 * @throws {RangeError} If `channels` is not a positive integer.
 */
function packPlanarRgb(pixels, channels, height, width) {
  if (!Number.isInteger(channels) || channels < 1) {
    throw new RangeError(`Invalid channel count: ${channels}`);
  }

  const planeSize = height * width;
  const planar = new Float32Array(3 * planeSize);
  const colourChannels = channels >= 3 ? 3 : 1;
  // Never read past the supplied data if it is shorter than expected.
  const pixelCount = Math.min(planeSize, Math.floor(pixels.length / channels));

  for (let pixel = 0; pixel < pixelCount; pixel += 1) {
    const base = pixel * channels;
    for (let plane = 0; plane < 3; plane += 1) {
      const sourceChannel = plane < colourChannels ? plane : 0;
      planar[plane * planeSize + pixel] = pixels[base + sourceChannel] / 255;
    }
  }

  return planar;
}

/**
 * Decides whether a cropped text region is upright (0) or upside down (180)
 * by running it through an ONNX classification session.
 *
 * The session is injected through {@link TextDirectionClassifier#initialize};
 * this class never creates or disposes it.
 */
class TextDirectionClassifier {
  constructor() {
    this.clsSession = null;
    this.config = null;
  }

  /**
   * Stores the inference session and configuration used by later calls.
   *
   * @param {object} clsSession - Session exposing `inputNames`, `outputNames` and `run(feeds)`.
   * @param {object} config - Stored as-is; not read by this class.
   */
  initialize(clsSession, config) {
    this.clsSession = clsSession;
    this.config = config;
  }

  /**
   * Classifies the orientation of one text region.
   *
   * This is best-effort: any failure (missing session, undecodable image,
   * inference error) is logged and the fallback result is returned instead
   * of rejecting. Note that the fallback reports confidence 1.0.
   *
   * @param {Buffer} textRegionBuffer - Encoded image of the text region.
   * @returns {Promise<{ clsResult: number, clsConfidence: number }>} Angle (0 or 180) and its score.
   */
  async classifyTextDirection(textRegionBuffer) {
    try {
      // Fail before doing any image work when the session was never supplied.
      if (!this.clsSession) {
        throw new Error('TextDirectionClassifier.initialize() must be called before classification');
      }

      const inputTensor = await this.prepareClsInput(textRegionBuffer);
      const [inputName] = this.clsSession.inputNames;
      const outputs = await this.clsSession.run({ [inputName]: inputTensor });
      return this.postprocessCls(outputs);
    } catch (error) {
      console.error('文本方向分类失败:', error);
      return fallbackResult();
    }
  }

  /**
   * Resizes the region to 192x48 and packs it into a float32 tensor of shape
   * `[1, 3, 48, 192]`, with every sample scaled to [0, 1].
   *
   * NOTE(review): only a division by 255 is applied and channels are taken in
   * the order the decoder returns them - confirm the model does not expect
   * mean/std normalisation or a different channel order.
   *
   * @param {Buffer} textRegionBuffer - Encoded image of the text region.
   * @returns {Promise<object>} Tensor ready to be fed to the session.
   */
  async prepareClsInput(textRegionBuffer) {
    // Loaded on demand: this is the only method that needs the ONNX and image
    // libraries, so the rest of the class stays usable without loading them.
    const [{ Tensor }, { default: sharp }] = await Promise.all([
      import('onnxruntime-node'),
      import('sharp'),
    ]);

    // Decode and resize straight to raw pixels; no intermediate re-encoding.
    const { data, info } = await sharp(textRegionBuffer)
      .resize(CLS_INPUT_WIDTH, CLS_INPUT_HEIGHT)
      .raw()
      .toBuffer({ resolveWithObject: true });

    const inputData = packPlanarRgb(data, info.channels, CLS_INPUT_HEIGHT, CLS_INPUT_WIDTH);
    return new Tensor('float32', inputData, [1, 3, CLS_INPUT_HEIGHT, CLS_INPUT_WIDTH]);
  }

  /**
   * Turns the raw session output into an angle and a score.
   *
   * Only the first output is inspected: index 0 is treated as the score for
   * 0 degrees and index 1 as the score for 180 degrees; the larger one wins,
   * with ties going to 0 degrees.
   *
   * @param {object} outputs - Map of output name to an object with a numeric `data` array.
   * @returns {{ clsResult: number, clsConfidence: number }} Fallback result when the output is missing or empty.
   */
  postprocessCls(outputs) {
    const [outputName] = this.clsSession.outputNames;
    const scores = outputs[outputName]?.data;
    // An absent or empty output cannot be scored.
    if (!scores || scores.length === 0) {
      return fallbackResult();
    }

    const isRotated = scores.length >= 2 && scores[1] > scores[0];
    const clsResult = isRotated ? 180 : 0;
    const clsConfidence = isRotated ? scores[1] : scores[0];

    console.log(`🧭 文本方向分类: ${clsResult}°, 置信度: ${clsConfidence.toFixed(4)}`);
    return { clsResult, clsConfidence };
  }
}

export default TextDirectionClassifier;