// server/utils/detectionProcessor.js
import { Tensor } from 'onnxruntime-node';
import sharp from 'sharp';

class DetectionProcessor {
  constructor() {
    this.session = null;
    this.config = null;
    // Simple console-backed logger tagged with [detection]
    this.logger = {
      info: (msg, ...args) => console.log(`🔍 [detection] ${msg}`, ...args),
      error: (msg, ...args) => console.error(`❌ [detection] ${msg}`, ...args),
      debug: (msg, ...args) => console.log(`🐛 [detection] ${msg}`, ...args)
    };
  }
  // Store the ONNX Runtime session and detection config supplied by the caller
  initialize(session, config) {
    this.session = session;
    this.config = config;
    this.logger.info('Detection processor initialized');
  }
  // Full detection pass: build the input tensor, run inference, post-process the output
  async detectText(processedImage) {
    const startTime = Date.now();
    this.logger.info('Starting text detection');
    try {
      const inputTensor = await this.prepareDetectionInput(processedImage);
      const outputs = await this.session.run({ [this.session.inputNames[0]]: inputTensor });
      const textBoxes = this.postprocessDetection(outputs, processedImage);
      const processingTime = Date.now() - startTime;
      this.logger.info(`Detection finished: ${textBoxes.length} regions in ${processingTime}ms`);
      return textBoxes;
    } catch (error) {
      this.logger.error('Detection failed', error);
      return [];
    }
  }
  // Decode the image to raw pixels and convert interleaved HWC data into a
  // planar CHW float32 tensor normalized to [0, 1]
  async prepareDetectionInput(processedImage) {
    const { buffer, width, height } = processedImage;
    this.logger.debug(`Preparing detection input: ${width}x${height}`);
    const imageData = await sharp(buffer)
      .ensureAlpha()
      .raw()
      .toBuffer({ resolveWithObject: true });
    const inputData = new Float32Array(3 * height * width);
    const data = imageData.data;
    const channels = imageData.info.channels;
    for (let i = 0; i < data.length; i += channels) {
      const pixelIndex = Math.floor(i / channels);
      const y = Math.floor(pixelIndex / width);
      const x = pixelIndex % width;
      for (let c = 0; c < 3; c++) {
        const inputIndex = c * height * width + y * width + x;
        if (inputIndex < inputData.length) {
          // Read the c-th channel of the pixel, not the red value for every plane
          inputData[inputIndex] = data[i + c] / 255.0;
        }
      }
    }
    this.logger.debug('Detection input tensor ready');
    return new Tensor('float32', inputData, [1, 3, height, width]);
  }
  // Turn the raw probability map into scored quadrilateral text boxes
  postprocessDetection(outputs, processedImage) {
    this.logger.debug('Starting detection post-processing');
    try {
      const boxes = [];
      const outputNames = this.session.outputNames;
      const detectionOutput = outputs[outputNames[0]];
      if (!detectionOutput) {
        this.logger.debug('Detection output is empty');
        return boxes;
      }
      const [batch, channels, height, width] = detectionOutput.dims;
      const data = detectionOutput.data;
      // Dynamic threshold adjustment
      const baseThreshold = this.config.detThresh || 0.05;
      const adaptiveThreshold = this.calculateAdaptiveThreshold(data, baseThreshold);
      this.logger.debug(`Using detection threshold: ${adaptiveThreshold.toFixed(4)}`);
      const points = this.collectDetectionPoints(data, width, height, adaptiveThreshold);
      if (points.length === 0) {
        this.logger.debug('No valid text points detected');
        return boxes;
      }
      this.logger.debug(`Collected ${points.length} detection points`);
      const clusters = this.enhancedCluster(points, this.config.clusterDistance || 8);
      this.logger.debug(`Clustering produced ${clusters.length} regions`);
      const validBoxes = this.filterAndScaleBoxes(clusters, processedImage);
      this.logger.info(`Generated ${validBoxes.length} valid text boxes`);
      return validBoxes.sort((a, b) => b.confidence - a.confidence);
    } catch (error) {
      this.logger.error('Detection post-processing error', error);
      return [];
    }
  }
  // Collect every probability-map pixel above the threshold, recording whether
  // each one is a local maximum within a small neighborhood
  collectDetectionPoints(data, width, height, threshold) {
    const points = [];
    let totalProb = 0;
    let maxProb = 0;
    for (let y = 0; y < height; y++) {
      for (let x = 0; x < width; x++) {
        const idx = y * width + x;
        const prob = data[idx];
        if (prob > threshold) {
          totalProb += prob;
          maxProb = Math.max(maxProb, prob);
          points.push({
            x, y, prob,
            localMax: this.isLocalMaximum(data, x, y, width, height, 2)
          });
        }
      }
    }
    if (points.length > 0) {
      this.logger.debug(`Detection point stats: mean confidence ${(totalProb / points.length).toFixed(4)}, max confidence ${maxProb.toFixed(4)}`);
    }
    return points;
  }
  // Adjust the detection threshold dynamically based on a random sample of the probability map
  calculateAdaptiveThreshold(data, baseThreshold) {
    let sum = 0;
    let count = 0;
    const sampleSize = Math.min(1000, data.length);
    for (let i = 0; i < sampleSize; i++) {
      const idx = Math.floor(Math.random() * data.length);
      if (data[idx] > baseThreshold) {
        sum += data[idx];
        count++;
      }
    }
    if (count === 0) return baseThreshold;
    const mean = sum / count;
    return Math.min(baseThreshold * 1.5, mean * 0.8);
  }
  // Convert point clusters into axis-aligned quads, filtering by size,
  // aspect ratio, and average confidence
  filterAndScaleBoxes(clusters, processedImage) {
    const boxes = [];
    const minPoints = this.config.minClusterPoints || 2;
    const boxThreshold = this.config.detBoxThresh || 0.1;
    for (const cluster of clusters) {
      if (cluster.length < minPoints) continue;
      const minX = Math.min(...cluster.map(p => p.x));
      const maxX = Math.max(...cluster.map(p => p.x));
      const minY = Math.min(...cluster.map(p => p.y));
      const maxY = Math.max(...cluster.map(p => p.y));
      const boxWidth = maxX - minX;
      const boxHeight = maxY - minY;
      // Relaxed size limits to keep small text regions
      if (boxWidth < 1 || boxHeight < 1) continue;
      const aspectRatio = boxWidth / boxHeight;
      if (aspectRatio > 150 || aspectRatio < 0.005) continue;
      const avgConfidence = cluster.reduce((sum, p) => sum + p.prob, 0) / cluster.length;
      if (avgConfidence > boxThreshold) {
        const box = this.scaleBoxToProcessedImage({
          x1: minX, y1: minY,
          x2: maxX, y2: minY,
          x3: maxX, y3: maxY,
          x4: minX, y4: maxY
        }, processedImage);
        box.confidence = avgConfidence;
        boxes.push(box);
      }
    }
    return boxes;
  }
  isLocalMaximum(data, x, y, width, height, radius) {
    const centerProb = data[y * width + x];
    for (let dy = -radius; dy <= radius; dy++) {
      for (let dx = -radius; dx <= radius; dx++) {
        if (dx === 0 && dy === 0) continue;
        const nx = x + dx;
        const ny = y + dy;
        if (nx >= 0 && nx < width && ny >= 0 && ny < height) {
          if (data[ny * width + nx] > centerProb) {
            return false;
          }
        }
      }
    }
    return true;
  }
  // Greedy BFS clustering over detection points, seeded from the highest-confidence points
  enhancedCluster(points, distanceThreshold) {
    const clusters = [];
    const visited = new Set();
    const sortedPoints = [...points].sort((a, b) => b.prob - a.prob);
    for (let i = 0; i < sortedPoints.length; i++) {
      if (visited.has(i)) continue;
      const cluster = [];
      const queue = [i];
      visited.add(i);
      while (queue.length > 0) {
        const currentIndex = queue.shift();
        const currentPoint = sortedPoints[currentIndex];
        cluster.push(currentPoint);
        // Widen the search radius for lower-confidence points
        const adaptiveThreshold = distanceThreshold * (1 + (1 - currentPoint.prob) * 0.3);
        for (let j = 0; j < sortedPoints.length; j++) {
          if (visited.has(j)) continue;
          const targetPoint = sortedPoints[j];
          const dist = Math.sqrt(
            Math.pow(targetPoint.x - currentPoint.x, 2) +
            Math.pow(targetPoint.y - currentPoint.y, 2)
          );
          if (dist < adaptiveThreshold) {
            queue.push(j);
            visited.add(j);
          }
        }
      }
      if (cluster.length > 0) {
        clusters.push(cluster);
      }
    }
    return clusters;
  }
  // Clamp each box corner to the bounds of the processed image
  scaleBoxToProcessedImage(box, processedImage) {
    const { width: processedWidth, height: processedHeight } = processedImage;
    const clamp = (value, max) => Math.max(0, Math.min(max, value));
    return {
      x1: clamp(box.x1, processedWidth - 1),
      y1: clamp(box.y1, processedHeight - 1),
      x2: clamp(box.x2, processedWidth - 1),
      y2: clamp(box.y2, processedHeight - 1),
      x3: clamp(box.x3, processedWidth - 1),
      y3: clamp(box.y3, processedHeight - 1),
      x4: clamp(box.x4, processedWidth - 1),
      y4: clamp(box.y4, processedHeight - 1)
    };
  }
}

export default DetectionProcessor;
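
// A minimal usage sketch, kept as a comment. The model path and the shape of
// `processedImage` are assumptions for illustration, not part of this module;
// the config keys (detThresh, detBoxThresh, clusterDistance) are the ones the
// class reads above.
//
//   import { InferenceSession } from 'onnxruntime-node';
//   import DetectionProcessor from './detectionProcessor.js';
//
//   const session = await InferenceSession.create('./models/det_model.onnx'); // hypothetical path
//   const detector = new DetectionProcessor();
//   detector.initialize(session, { detThresh: 0.05, detBoxThresh: 0.1, clusterDistance: 8 });
//   // processedImage: { buffer, width, height }, where buffer is any image sharp can decode
//   const boxes = await detector.detectText(processedImage);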