85 行
2.7 KiB
JavaScript
85 行
2.7 KiB
JavaScript
|
|
// server/utils/textDirectionClassifier.js
|
||
|
|
import { Tensor } from 'onnxruntime-node';
|
||
|
|
import sharp from 'sharp';
|
||
|
|
|
||
|
|
/**
 * Decides whether a cropped text region is upright (0°) or upside down (180°)
 * by running it through a classification session supplied via initialize().
 *
 * Every public entry point returns `{ clsResult, clsConfidence }`, where
 * `clsResult` is either 0 or 180.
 */
class TextDirectionClassifier {
  constructor() {
    // Both fields are injected later through initialize(); null means "not ready".
    this.clsSession = null;
    this.config = null;
  }

  /**
   * Stores the inference session and configuration for later calls.
   *
   * @param {object} clsSession - Session exposing `inputNames`, `outputNames` and an async `run(feeds)`.
   * @param {object} config - Stored as-is; this class does not read it.
   */
  initialize(clsSession, config) {
    this.clsSession = clsSession;
    this.config = config;
  }

  /**
   * Runs the full pipeline (preprocess, inference, postprocess) on one text region.
   *
   * This is deliberately best-effort: any failure is logged and reported as
   * "upright" so that a classifier problem never aborts the caller.
   *
   * @param {Buffer} textRegionBuffer - Encoded image of the cropped text region.
   * @returns {Promise<{clsResult: number, clsConfidence: number}>} Detected angle and its score;
   *   `{ clsResult: 0, clsConfidence: 1.0 }` when anything goes wrong.
   */
  async classifyTextDirection(textRegionBuffer) {
    try {
      if (!this.clsSession) {
        // Fail with a readable reason instead of "cannot read properties of null".
        throw new Error('clsSession is not initialized; call initialize() first');
      }

      const inputTensor = await this.prepareClsInput(textRegionBuffer);
      const feeds = { [this.clsSession.inputNames[0]]: inputTensor };
      const outputs = await this.clsSession.run(feeds);
      return this.postprocessCls(outputs);
    } catch (error) {
      console.error('文本方向分类失败:', error);
      return { clsResult: 0, clsConfidence: 1.0 };
    }
  }

  /**
   * Turns an encoded image into a float32 tensor of shape [1, 3, 48, 192].
   *
   * @param {Buffer} textRegionBuffer - Encoded image of the cropped text region.
   * @returns {Promise<Tensor>} Planar (channel-first) tensor with samples scaled to [0, 1].
   */
  async prepareClsInput(textRegionBuffer) {
    const targetHeight = 48;
    const targetWidth = 192;

    // NOTE(review): resize() is called without a `fit` option, so sharp's default
    // resizing strategy applies — confirm it matches how the model was trained.
    const resizedBuffer = await sharp(textRegionBuffer)
      .resize(targetWidth, targetHeight)
      .png()
      .toBuffer();

    // The PNG round trip is kept from the original implementation; presumably it
    // normalises the pixel format before raw extraction — TODO confirm it is needed.
    const { data, info } = await sharp(resizedBuffer)
      .ensureAlpha()
      .raw()
      .toBuffer({ resolveWithObject: true });

    const inputData = this.toPlanarRgb(data, info.channels, targetHeight, targetWidth);
    return new Tensor('float32', inputData, [1, 3, targetHeight, targetWidth]);
  }

  /**
   * Internal helper: converts interleaved 8-bit samples (pixel after pixel, e.g.
   * R,G,B,A,R,G,B,A,...) into three consecutive planes (all R, then all G, then
   * all B), each sample divided by 255.
   *
   * The previous implementation only ever wrote the first plane, leaving the
   * other two filled with zeros.
   *
   * @param {ArrayLike<number>} data - Interleaved samples.
   * @param {number} channels - Samples per pixel in `data`.
   * @param {number} height - Plane height in pixels.
   * @param {number} width - Plane width in pixels.
   * @returns {Float32Array} Buffer of length `3 * height * width`; pixels missing from
   *   `data` stay 0.
   */
  toPlanarRgb(data, channels, height, width) {
    const planeSize = height * width;
    const planar = new Float32Array(3 * planeSize);

    if (!Number.isInteger(channels) || channels < 1) {
      return planar;
    }

    // Never read past the end of `data`, never write past the end of a plane.
    const pixelCount = Math.min(planeSize, Math.floor(data.length / channels));
    // With fewer than three samples per pixel there is no separate G/B sample,
    // so the first sample is replicated into all three planes.
    const hasColour = channels >= 3;

    for (let pixel = 0; pixel < pixelCount; pixel += 1) {
      const base = pixel * channels;
      for (let plane = 0; plane < 3; plane += 1) {
        const sample = data[base + (hasColour ? plane : 0)];
        // NOTE(review): only a [0, 1] rescale is applied; confirm the model does
        // not expect additional mean/std normalisation.
        planar[plane * planeSize + pixel] = sample / 255.0;
      }
    }

    return planar;
  }

  /**
   * Maps the raw session output to an angle and score.
   *
   * The first output is read as a score vector: index 0 is treated as 0° and
   * index 1 as 180°; the larger of the two wins (ties go to 0°).
   *
   * @param {Object<string, {data: ArrayLike<number>}>} outputs - Result of `clsSession.run()`.
   * @returns {{clsResult: number, clsConfidence: number}} Winning angle and its score;
   *   `{ clsResult: 0, clsConfidence: 1.0 }` when the output is missing or empty.
   */
  postprocessCls(outputs) {
    const clsOutput = outputs[this.clsSession.outputNames[0]];
    if (!clsOutput) {
      return { clsResult: 0, clsConfidence: 1.0 };
    }

    const scores = clsOutput.data;
    if (!scores || scores.length === 0) {
      // Without this guard `scores[0].toFixed` below would throw a TypeError.
      return { clsResult: 0, clsConfidence: 1.0 };
    }

    const isFlipped = scores.length >= 2 && scores[1] > scores[0];
    const clsResult = isFlipped ? 180 : 0;
    const clsConfidence = isFlipped ? scores[1] : scores[0];

    console.log(`🧭 文本方向分类: ${clsResult}°, 置信度: ${clsConfidence.toFixed(4)}`);
    return { clsResult, clsConfidence };
  }
}
|
||
|
|
|
||
|
|
export default TextDirectionClassifier;
|