Electron-vue3-ts-offline/server/utils/textDirectionClassifier.js
2025-11-13 16:34:41 +08:00

85 行
2.7 KiB
JavaScript

// server/utils/textDirectionClassifier.js
import { Tensor } from 'onnxruntime-node';
import sharp from 'sharp';
/**
 * Classifies whether a cropped text-region image is upright (0°) or upside down
 * (180°) using an ONNX Runtime session supplied via `initialize()`.
 */
class TextDirectionClassifier {
  constructor() {
    // Inference session for the direction model; assigned by initialize().
    this.clsSession = null;
    // Caller-supplied configuration; stored on the instance but not read by this class.
    this.config = null;
  }

  /**
   * Injects the inference session and configuration.
   * @param {object} clsSession - session exposing `run()`, `inputNames` and `outputNames`.
   * @param {object} config - stored as-is on the instance.
   */
  initialize(clsSession, config) {
    this.clsSession = clsSession;
    this.config = config;
  }

  /**
   * Runs direction classification on one encoded image buffer.
   * Best-effort: any failure (including a missing session) is logged and the neutral
   * result `{ clsResult: 0, clsConfidence: 1.0 }` is returned instead of throwing.
   * @param {Buffer} textRegionBuffer - encoded image that sharp can decode.
   * @returns {Promise<{clsResult: number, clsConfidence: number}>}
   */
  async classifyTextDirection(textRegionBuffer) {
    try {
      // Fail before doing any image work when initialize() was never called.
      if (!this.clsSession) {
        throw new Error('TextDirectionClassifier is not initialized; call initialize() first');
      }
      const inputTensor = await this.prepareClsInput(textRegionBuffer);
      const outputs = await this.clsSession.run({ [this.clsSession.inputNames[0]]: inputTensor });
      return this.postprocessCls(outputs);
    } catch (error) {
      console.error('文本方向分类失败:', error);
      return { clsResult: 0, clsConfidence: 1.0 };
    }
  }

  /**
   * Builds the `[1, 3, 48, 192]` float32 input tensor for the classifier.
   * The region is resized to height 48 keeping its aspect ratio (width capped at 192,
   * so very wide lines are squeezed rather than cropped). Any remaining columns on the
   * right are left as zeros.
   * NOTE(review): values are scaled to [0, 1] (x / 255), as before. PaddleOCR's
   * reference cls preprocessing additionally applies (x - 0.5) / 0.5 — confirm which
   * normalization the exported model expects.
   * @param {Buffer} textRegionBuffer - encoded image that sharp can decode.
   * @returns {Promise<Tensor>}
   */
  async prepareClsInput(textRegionBuffer) {
    const targetHeight = 48;
    const targetWidth = 192;
    const { width, height } = await sharp(textRegionBuffer).metadata();
    const resizedWidth = TextDirectionClassifier.computeResizedWidth(
      width,
      height,
      targetWidth,
      targetHeight,
    );
    // fit: 'fill' resizes to exactly these dimensions; sharp's default ('cover')
    // would center-crop and drop text from long lines.
    const { data, info } = await sharp(textRegionBuffer)
      .resize(resizedWidth, targetHeight, { fit: 'fill' })
      .raw()
      .toBuffer({ resolveWithObject: true });
    const inputData = TextDirectionClassifier.toPlanarFloat32(data, info, targetWidth, targetHeight);
    return new Tensor('float32', inputData, [1, 3, targetHeight, targetWidth]);
  }

  /**
   * Width to resize a region to so that, at `targetHeight`, its aspect ratio is kept,
   * capped at `targetWidth` and at least 1. Returns `targetWidth` when either source
   * dimension is missing or non-positive.
   * @param {number|undefined} width - source width in pixels.
   * @param {number|undefined} height - source height in pixels.
   * @param {number} targetWidth
   * @param {number} targetHeight
   * @returns {number}
   */
  static computeResizedWidth(width, height, targetWidth, targetHeight) {
    if (!(width > 0) || !(height > 0)) {
      return targetWidth;
    }
    // Multiply before dividing so exact integer ratios are not lost to rounding.
    const scaled = Math.ceil((targetHeight * width) / height);
    return Math.min(targetWidth, Math.max(1, scaled));
  }

  /**
   * Converts interleaved (row-major, channel-last) 8-bit pixel data into a planar
   * Float32Array laid out as 3 planes of `targetHeight * targetWidth`, scaled by 1/255.
   * Pixels outside the source image stay 0 (right/bottom padding). Sources with fewer
   * than 3 channels (grey, grey+alpha) have their first channel replicated into all
   * three planes. Channels beyond the third (e.g. alpha) are ignored.
   * @param {Uint8Array} data - raw pixel bytes, `width * height * channels` long.
   * @param {{width: number, height: number, channels: number}} info - layout of `data`.
   * @param {number} targetWidth
   * @param {number} targetHeight
   * @returns {Float32Array} length `3 * targetHeight * targetWidth`.
   */
  static toPlanarFloat32(data, { width, height, channels }, targetWidth, targetHeight) {
    const planeSize = targetHeight * targetWidth;
    const planar = new Float32Array(3 * planeSize);
    const rows = Math.min(height, targetHeight);
    const cols = Math.min(width, targetWidth);
    for (let y = 0; y < rows; y += 1) {
      for (let x = 0; x < cols; x += 1) {
        const src = (y * width + x) * channels;
        const dst = y * targetWidth + x;
        for (let c = 0; c < 3; c += 1) {
          const value = channels >= 3 ? data[src + c] : data[src];
          planar[c * planeSize + dst] = value / 255;
        }
      }
    }
    return planar;
  }

  /**
   * Maps the model output to a direction: 180 when the second score beats the
   * first, otherwise 0, together with the winning score.
   * A missing or empty output yields `{ clsResult: 0, clsConfidence: 1.0 }`.
   * @param {object} outputs - result of `clsSession.run()`, keyed by output name.
   * @returns {{clsResult: number, clsConfidence: number}}
   */
  postprocessCls(outputs) {
    const [outputName] = this.clsSession.outputNames;
    const clsOutput = outputs[outputName];
    // Without this guard an empty tensor leaves the confidence undefined and
    // the toFixed() call below throws.
    if (!clsOutput || !clsOutput.data || clsOutput.data.length === 0) {
      return { clsResult: 0, clsConfidence: 1.0 };
    }
    const { data } = clsOutput;
    const isFlipped = data.length >= 2 && data[1] > data[0];
    const clsResult = isFlipped ? 180 : 0;
    const clsConfidence = isFlipped ? data[1] : data[0];
    console.log(`🧭 文本方向分类: ${clsResult}°, 置信度: ${clsConfidence.toFixed(4)}`);
    return { clsResult, clsConfidence };
  }
}

export default TextDirectionClassifier;