textDirectionClassifier.js 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. // server/utils/textDirectionClassifier.js
  2. import { Tensor } from 'onnxruntime-node';
  3. import sharp from 'sharp';
  4. class TextDirectionClassifier {
  5. constructor() {
  6. this.clsSession = null;
  7. this.config = null;
  8. }
  9. initialize(clsSession, config) {
  10. this.clsSession = clsSession;
  11. this.config = config;
  12. }
  13. async classifyTextDirection(textRegionBuffer) {
  14. try {
  15. const inputTensor = await this.prepareClsInput(textRegionBuffer);
  16. const outputs = await this.clsSession.run({ [this.clsSession.inputNames[0]]: inputTensor });
  17. return this.postprocessCls(outputs);
  18. } catch (error) {
  19. console.error('文本方向分类失败:', error);
  20. return { clsResult: 0, clsConfidence: 1.0 };
  21. }
  22. }
  23. async prepareClsInput(textRegionBuffer) {
  24. const targetHeight = 48;
  25. const targetWidth = 192;
  26. const resizedBuffer = await sharp(textRegionBuffer)
  27. .resize(targetWidth, targetHeight)
  28. .png()
  29. .toBuffer();
  30. const imageData = await sharp(resizedBuffer)
  31. .ensureAlpha()
  32. .raw()
  33. .toBuffer({ resolveWithObject: true });
  34. const inputData = new Float32Array(3 * targetHeight * targetWidth);
  35. const data = imageData.data;
  36. const channels = imageData.info.channels;
  37. for (let i = 0; i < data.length; i += channels) {
  38. const pixelIndex = Math.floor(i / channels);
  39. const channel = Math.floor(pixelIndex / (targetHeight * targetWidth));
  40. const posInChannel = pixelIndex % (targetHeight * targetWidth);
  41. if (channel < 3) {
  42. const y = Math.floor(posInChannel / targetWidth);
  43. const x = posInChannel % targetWidth;
  44. const inputIndex = channel * targetHeight * targetWidth + y * targetWidth + x;
  45. if (inputIndex < inputData.length) {
  46. inputData[inputIndex] = data[i] / 255.0;
  47. }
  48. }
  49. }
  50. return new Tensor('float32', inputData, [1, 3, targetHeight, targetWidth]);
  51. }
  52. postprocessCls(outputs) {
  53. const outputNames = this.clsSession.outputNames;
  54. const clsOutput = outputs[outputNames[0]];
  55. if (!clsOutput) return { clsResult: 0, clsConfidence: 1.0 };
  56. const data = clsOutput.data;
  57. let clsResult = 0;
  58. let clsConfidence = data[0];
  59. if (data.length >= 2 && data[1] > data[0]) {
  60. clsResult = 180;
  61. clsConfidence = data[1];
  62. }
  63. console.log(`🧭 文本方向分类: ${clsResult}°, 置信度: ${clsConfidence.toFixed(4)}`);
  64. return { clsResult, clsConfidence };
  65. }
  66. }
  67. export default TextDirectionClassifier;