// server/utils/detectionProcessor.js
import { Tensor } from 'onnxruntime-node';
import sharp from 'sharp';
  4. class DetectionProcessor {
  5. constructor() {
  6. this.session = null;
  7. this.config = null;
  8. }
  9. initialize(session, config) {
  10. this.session = session;
  11. this.config = config;
  12. }
  13. async detectText(processedImage) {
  14. try {
  15. const inputTensor = await this.prepareDetectionInput(processedImage);
  16. const outputs = await this.session.run({ [this.session.inputNames[0]]: inputTensor });
  17. const textBoxes = this.postprocessDetection(outputs, processedImage);
  18. return textBoxes;
  19. } catch (error) {
  20. console.error('文本检测失败:', error);
  21. return [];
  22. }
  23. }
  24. async prepareDetectionInput(processedImage) {
  25. const { buffer, width, height } = processedImage;
  26. const imageData = await sharp(buffer)
  27. .ensureAlpha()
  28. .raw()
  29. .toBuffer({ resolveWithObject: true });
  30. const inputData = new Float32Array(3 * height * width);
  31. const data = imageData.data;
  32. const channels = imageData.info.channels;
  33. for (let i = 0; i < data.length; i += channels) {
  34. const pixelIndex = Math.floor(i / channels);
  35. const channel = Math.floor(pixelIndex / (height * width));
  36. const posInChannel = pixelIndex % (height * width);
  37. if (channel < 3) {
  38. const y = Math.floor(posInChannel / width);
  39. const x = posInChannel % width;
  40. const inputIndex = channel * height * width + y * width + x;
  41. if (inputIndex < inputData.length) {
  42. inputData[inputIndex] = data[i] / 255.0;
  43. }
  44. }
  45. }
  46. return new Tensor('float32', inputData, [1, 3, height, width]);
  47. }
  48. postprocessDetection(outputs, processedImage) {
  49. try {
  50. const boxes = [];
  51. const outputNames = this.session.outputNames;
  52. const detectionOutput = outputs[outputNames[0]];
  53. if (!detectionOutput) {
  54. return boxes;
  55. }
  56. const [batch, channels, height, width] = detectionOutput.dims;
  57. const data = detectionOutput.data;
  58. // 降低检测阈值,提高召回率
  59. const threshold = this.config.detThresh || 0.05;
  60. const points = [];
  61. // 改进的点收集逻辑
  62. for (let y = 0; y < height; y++) {
  63. for (let x = 0; x < width; x++) {
  64. const idx = y * width + x;
  65. const prob = data[idx];
  66. if (prob > threshold) {
  67. points.push({
  68. x,
  69. y,
  70. prob,
  71. localMax: this.isLocalMaximum(data, x, y, width, height, 2)
  72. });
  73. }
  74. }
  75. }
  76. if (points.length === 0) {
  77. return boxes;
  78. }
  79. // 改进的聚类算法
  80. const clusters = this.enhancedCluster(points, 8);
  81. for (const cluster of clusters) {
  82. // 降低最小点数要求
  83. if (cluster.length < 2) continue;
  84. const minX = Math.min(...cluster.map(p => p.x));
  85. const maxX = Math.max(...cluster.map(p => p.x));
  86. const minY = Math.min(...cluster.map(p => p.y));
  87. const maxY = Math.max(...cluster.map(p => p.y));
  88. const boxWidth = maxX - minX;
  89. const boxHeight = maxY - minY;
  90. // 放宽尺寸限制
  91. if (boxWidth < 2 || boxHeight < 2) continue;
  92. const aspectRatio = boxWidth / boxHeight;
  93. // 放宽宽高比限制
  94. if (aspectRatio > 100 || aspectRatio < 0.01) continue;
  95. const avgConfidence = cluster.reduce((sum, p) => sum + p.prob, 0) / cluster.length;
  96. // 降低框置信度阈值
  97. const boxThreshold = this.config.detBoxThresh || 0.1;
  98. if (avgConfidence > boxThreshold) {
  99. const box = this.scaleBoxToProcessedImage({
  100. x1: minX, y1: minY,
  101. x2: maxX, y2: minY,
  102. x3: maxX, y3: maxY,
  103. x4: minX, y4: maxY
  104. }, processedImage);
  105. box.confidence = avgConfidence;
  106. boxes.push(box);
  107. }
  108. }
  109. boxes.sort((a, b) => b.confidence - a.confidence);
  110. console.log(`✅ 检测到 ${boxes.length} 个文本区域`);
  111. return boxes;
  112. } catch (error) {
  113. console.error('检测后处理错误:', error);
  114. return [];
  115. }
  116. }
  117. // 添加局部最大值检测
  118. isLocalMaximum(data, x, y, width, height, radius) {
  119. const centerProb = data[y * width + x];
  120. for (let dy = -radius; dy <= radius; dy++) {
  121. for (let dx = -radius; dx <= radius; dx++) {
  122. if (dx === 0 && dy === 0) continue;
  123. const nx = x + dx;
  124. const ny = y + dy;
  125. if (nx >= 0 && nx < width && ny >= 0 && ny < height) {
  126. if (data[ny * width + nx] > centerProb) {
  127. return false;
  128. }
  129. }
  130. }
  131. }
  132. return true;
  133. }
  134. // 改进的聚类算法
  135. enhancedCluster(points, distanceThreshold) {
  136. const clusters = [];
  137. const visited = new Set();
  138. // 按概率降序排序,优先处理高置信度点
  139. const sortedPoints = [...points].sort((a, b) => b.prob - a.prob);
  140. for (let i = 0; i < sortedPoints.length; i++) {
  141. if (visited.has(i)) continue;
  142. const cluster = [];
  143. const queue = [i];
  144. visited.add(i);
  145. while (queue.length > 0) {
  146. const currentIndex = queue.shift();
  147. const currentPoint = sortedPoints[currentIndex];
  148. cluster.push(currentPoint);
  149. // 动态调整搜索半径
  150. const adaptiveThreshold = distanceThreshold *
  151. (1 + (1 - currentPoint.prob) * 0.5);
  152. for (let j = 0; j < sortedPoints.length; j++) {
  153. if (visited.has(j)) continue;
  154. const targetPoint = sortedPoints[j];
  155. const dist = Math.sqrt(
  156. Math.pow(targetPoint.x - currentPoint.x, 2) +
  157. Math.pow(targetPoint.y - currentPoint.y, 2)
  158. );
  159. if (dist < adaptiveThreshold) {
  160. queue.push(j);
  161. visited.add(j);
  162. }
  163. }
  164. }
  165. if (cluster.length > 0) {
  166. clusters.push(cluster);
  167. }
  168. }
  169. return clusters;
  170. }
  171. scaleBoxToProcessedImage(box, processedImage) {
  172. const { width: processedWidth, height: processedHeight } = processedImage;
  173. const scaledBox = {
  174. x1: box.x1,
  175. y1: box.y1,
  176. x2: box.x2,
  177. y2: box.y2,
  178. x3: box.x3,
  179. y3: box.y3,
  180. x4: box.x4,
  181. y4: box.y4
  182. };
  183. const clamp = (value, max) => Math.max(0, Math.min(max, value));
  184. return {
  185. x1: clamp(scaledBox.x1, processedWidth - 1),
  186. y1: clamp(scaledBox.y1, processedHeight - 1),
  187. x2: clamp(scaledBox.x2, processedWidth - 1),
  188. y2: clamp(scaledBox.y2, processedHeight - 1),
  189. x3: clamp(scaledBox.x3, processedWidth - 1),
  190. y3: clamp(scaledBox.y3, processedHeight - 1),
  191. x4: clamp(scaledBox.x4, processedWidth - 1),
  192. y4: clamp(scaledBox.y4, processedHeight - 1)
  193. };
  194. }
  195. scaleBoxToOriginalImage(box, processedImage) {
  196. const {
  197. scaleX, scaleY,
  198. paddingX, paddingY,
  199. originalWidth, originalHeight
  200. } = processedImage;
  201. const paddedX1 = box.x1 * scaleX;
  202. const paddedY1 = box.y1 * scaleY;
  203. const paddedX3 = box.x3 * scaleX;
  204. const paddedY3 = box.y3 * scaleY;
  205. const originalX1 = paddedX1 - paddingX;
  206. const originalY1 = paddedY1 - paddingY;
  207. const originalX3 = paddedX3 - paddingX;
  208. const originalY3 = paddedY3 - paddingY;
  209. const clamp = (value, max) => Math.max(0, Math.min(max, value));
  210. return {
  211. x1: clamp(originalX1, originalWidth - 1),
  212. y1: clamp(originalY1, originalHeight - 1),
  213. x2: clamp(originalX3, originalWidth - 1),
  214. y2: clamp(originalY1, originalHeight - 1),
  215. x3: clamp(originalX3, originalWidth - 1),
  216. y3: clamp(originalY3, originalHeight - 1),
  217. x4: clamp(originalX1, originalWidth - 1),
  218. y4: clamp(originalY3, originalHeight - 1),
  219. confidence: box.confidence
  220. };
  221. }
  222. }
  223. export default DetectionProcessor;