XuqmGroup-Server/tenant-service/src/main/java/com/xuqm/tenant/controller/DatabaseController.java

193 行
8.6 KiB
Java

package com.xuqm.tenant.controller;
import com.xuqm.tenant.config.PrivateDeploymentProperties;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
/**
 * Read-only database browser exposed under {@code /api/system/database}.
 *
 * <p>Every endpoint answers HTTP 403 unless {@link PrivateDeploymentProperties#isPrivate()}
 * returns true. Table and column identifiers are backtick-quoted before being spliced into SQL,
 * and every request-supplied value (keyword, paging) is bound as a statement parameter.
 * Backtick quoting and {@code LIMIT ? OFFSET ?} presumably target a MySQL-compatible
 * database — TODO confirm.
 *
 * <p>NOTE(review): there is no per-table allowlist, so any table reachable through the
 * DataSource's current catalog can be read in full. Confirm that access to this controller is
 * restricted to administrators elsewhere (e.g. by a security filter).
 */
@RestController
@RequestMapping("/api/system/database")
public class DatabaseController {

    /** Largest page size {@link #queryData} will serve; larger requests are clamped. */
    private static final int MAX_PAGE_SIZE = 200;

    /** Escape character used in the keyword {@code LIKE ... ESCAPE} clause. */
    private static final String LIKE_ESCAPE = "!";

    /** Accepted shape of a table name taken from the request path. */
    private static final Pattern TABLE_NAME_PATTERN = Pattern.compile("[a-zA-Z0-9_]+");

    private static final String PRIVATE_ONLY_MESSAGE = "此接口仅在私有化部署可用";
    private static final String TABLE_FORBIDDEN_MESSAGE = "不允许访问该表";

    private final PrivateDeploymentProperties deployProps;
    private final DataSource dataSource;

    public DatabaseController(PrivateDeploymentProperties deployProps, DataSource dataSource) {
        this.deployProps = deployProps;
        this.dataSource = dataSource;
    }

    /**
     * Lists the base tables ({@code TABLE} type) of the connection's current catalog.
     *
     * @return 200 with {@code {"data": [{"name", "comment"}, ...]}}, 403 outside private
     *     deployments, or 500 with the driver's error message
     */
    @GetMapping("/tables")
    public ResponseEntity<?> listTables() {
        if (!deployProps.isPrivate()) {
            return forbidden(PRIVATE_ONLY_MESSAGE);
        }
        try (Connection conn = dataSource.getConnection();
             ResultSet rs = conn.getMetaData()
                     .getTables(conn.getCatalog(), null, "%", new String[]{"TABLE"})) {
            List<Map<String, Object>> tables = new ArrayList<>();
            while (rs.next()) {
                Map<String, Object> t = new LinkedHashMap<>();
                t.put("name", rs.getString("TABLE_NAME"));
                t.put("comment", rs.getString("REMARKS"));
                tables.add(t);
            }
            return ResponseEntity.ok(Map.of("data", tables));
        } catch (SQLException e) {
            return serverError("查询表列表失败: ", e);
        }
    }

    /**
     * Lists the columns of one table.
     *
     * @param tableName table name; must match {@code [a-zA-Z0-9_]+}
     * @return 200 with {@code {"data": [{"name", "type", "size", "nullable", "comment"}, ...]}},
     *     403 outside private deployments or for a malformed name, or 500 on a database error
     */
    @GetMapping("/tables/{tableName}/columns")
    public ResponseEntity<?> listColumns(@PathVariable String tableName) {
        if (!deployProps.isPrivate()) {
            return forbidden(PRIVATE_ONLY_MESSAGE);
        }
        if (!isAllowedTable(tableName)) {
            return forbidden(TABLE_FORBIDDEN_MESSAGE);
        }
        try (Connection conn = dataSource.getConnection()) {
            return ResponseEntity.ok(Map.of("data", readColumns(conn, tableName)));
        } catch (SQLException e) {
            return serverError("查询列信息失败: ", e);
        }
    }

    /**
     * Returns one page of rows from a table, with optional keyword search and sorting.
     *
     * @param tableName     table name; must match {@code [a-zA-Z0-9_]+}
     * @param page          zero-based page index; negative values are treated as 0
     * @param size          rows per page, clamped to {@code [0, 200]}
     * @param keyword       optional literal text; a row matches when any CHAR/TEXT-typed column
     *                      contains it ({@code %} and {@code _} are not wildcards)
     * @param sortColumn    optional column to sort by; ignored unless it is a column of the table
     * @param sortDirection {@code DESC} (case-insensitive) sorts descending; anything else ascending
     * @return 200 with {@code {"data": {columns, rows, total, totalPages, page, size}}}, 403
     *     outside private deployments or for a malformed name, or 500 on a database error
     */
    @GetMapping("/tables/{tableName}/data")
    public ResponseEntity<?> queryData(
            @PathVariable String tableName,
            @RequestParam(defaultValue = "0") int page,
            @RequestParam(defaultValue = "50") int size,
            @RequestParam(required = false) String keyword,
            @RequestParam(required = false) String sortColumn,
            @RequestParam(required = false, defaultValue = "ASC") String sortDirection) {
        if (!deployProps.isPrivate()) {
            return forbidden(PRIVATE_ONLY_MESSAGE);
        }
        if (!isAllowedTable(tableName)) {
            return forbidden(TABLE_FORBIDDEN_MESSAGE);
        }
        // Negative paging values would otherwise reach the database as an invalid LIMIT/OFFSET.
        page = Math.max(page, 0);
        size = Math.min(Math.max(size, 0), MAX_PAGE_SIZE);
        // long arithmetic: page * size can overflow int for very large page indexes.
        long offset = (long) page * size;

        try (Connection conn = dataSource.getConnection()) {
            // Column names double as the sort-column allowlist; text columns drive keyword search.
            List<String> allColumns = new ArrayList<>();
            List<String> textColumns = new ArrayList<>();
            for (Map<String, Object> col : readColumns(conn, tableName)) {
                String name = (String) col.get("name");
                allColumns.add(name);
                if (isTextType((String) col.get("type"))) {
                    textColumns.add(name);
                }
            }

            // Keyword search: OR of "<col> LIKE ?" over every text column, value bound as a parameter.
            StringBuilder whereClause = new StringBuilder();
            List<Object> params = new ArrayList<>();
            if (keyword != null && !keyword.isBlank() && !textColumns.isEmpty()) {
                String likeValue = "%" + escapeLike(keyword) + "%";
                for (String col : textColumns) {
                    whereClause.append(whereClause.length() == 0 ? " WHERE " : " OR ");
                    whereClause.append(sanitizeIdentifier(col))
                            .append(" LIKE ? ESCAPE '").append(LIKE_ESCAPE).append("'");
                    params.add(likeValue);
                }
            }

            String quotedTable = sanitizeIdentifier(tableName);

            long total = 0;
            try (PreparedStatement ps =
                         conn.prepareStatement("SELECT COUNT(*) FROM " + quotedTable + whereClause)) {
                bindParams(ps, params);
                try (ResultSet rs = ps.executeQuery()) {
                    if (rs.next()) {
                        total = rs.getLong(1);
                    }
                }
            }

            // Only a column that really exists in the table may reach ORDER BY.
            String orderClause = "";
            if (sortColumn != null && !sortColumn.isBlank() && allColumns.contains(sortColumn)) {
                String dir = "DESC".equalsIgnoreCase(sortDirection) ? "DESC" : "ASC";
                orderClause = " ORDER BY " + sanitizeIdentifier(sortColumn) + " " + dir;
            }

            String dataSql = "SELECT * FROM " + quotedTable + whereClause + orderClause
                    + " LIMIT ? OFFSET ?";
            List<Map<String, Object>> rows = new ArrayList<>();
            try (PreparedStatement ps = conn.prepareStatement(dataSql)) {
                int idx = bindParams(ps, params);
                ps.setInt(idx++, size);
                ps.setLong(idx, offset);
                try (ResultSet rs = ps.executeQuery()) {
                    ResultSetMetaData rsmd = rs.getMetaData();
                    int colCount = rsmd.getColumnCount();
                    while (rs.next()) {
                        Map<String, Object> row = new LinkedHashMap<>();
                        for (int i = 1; i <= colCount; i++) {
                            row.put(rsmd.getColumnName(i), rs.getObject(i));
                        }
                        rows.add(row);
                    }
                }
            }

            int totalPages = size > 0 ? (int) Math.ceil((double) total / size) : 0;
            Map<String, Object> result = new LinkedHashMap<>();
            result.put("columns", allColumns);
            result.put("rows", rows);
            result.put("total", total);
            result.put("totalPages", totalPages);
            result.put("page", page);
            result.put("size", size);
            return ResponseEntity.ok(Map.of("data", result));
        } catch (SQLException e) {
            return serverError("查询数据失败: ", e);
        }
    }

    /**
     * Reads JDBC column metadata for exactly {@code tableName} in the connection's catalog.
     *
     * <p>{@link DatabaseMetaData#getColumns} interprets the table name as a LIKE pattern, in
     * which {@code _} matches any single character. Without the name check below, "user_role"
     * would also pull in the columns of a table such as "userXrole".
     *
     * @return one map per column with keys name, type, size, nullable, comment
     */
    private static List<Map<String, Object>> readColumns(Connection conn, String tableName)
            throws SQLException {
        DatabaseMetaData meta = conn.getMetaData();
        List<Map<String, Object>> columns = new ArrayList<>();
        try (ResultSet rs = meta.getColumns(conn.getCatalog(), null, tableName, "%")) {
            while (rs.next()) {
                // Case-insensitive so the check never rejects a row the driver matched by case folding.
                if (!tableName.equalsIgnoreCase(rs.getString("TABLE_NAME"))) {
                    continue;
                }
                Map<String, Object> c = new LinkedHashMap<>();
                c.put("name", rs.getString("COLUMN_NAME"));
                c.put("type", rs.getString("TYPE_NAME"));
                c.put("size", rs.getInt("COLUMN_SIZE"));
                c.put("nullable", rs.getInt("NULLABLE") == DatabaseMetaData.columnNullable);
                c.put("comment", rs.getString("REMARKS"));
                columns.add(c);
            }
        }
        return columns;
    }

    /** True for driver type names containing CHAR or TEXT (e.g. VARCHAR, LONGTEXT), any case. */
    private static boolean isTextType(String typeName) {
        if (typeName == null) {
            return false;
        }
        String upper = typeName.toUpperCase(Locale.ROOT);
        return upper.contains("CHAR") || upper.contains("TEXT");
    }

    /** Escapes LIKE wildcards (and the escape character itself) so {@code value} matches literally. */
    private static String escapeLike(String value) {
        return value.replace(LIKE_ESCAPE, LIKE_ESCAPE + LIKE_ESCAPE)
                .replace("%", LIKE_ESCAPE + "%")
                .replace("_", LIKE_ESCAPE + "_");
    }

    /**
     * Binds {@code params} to positions 1..n of {@code ps}.
     *
     * @return the next free parameter index (n + 1)
     */
    private static int bindParams(PreparedStatement ps, List<?> params) throws SQLException {
        int idx = 1;
        for (Object p : params) {
            ps.setObject(idx++, p);
        }
        return idx;
    }

    /** Builds a 403 response carrying {@code message}. */
    private static ResponseEntity<?> forbidden(String message) {
        return ResponseEntity.status(403).body(Map.of("message", message));
    }

    /**
     * Builds a 500 response whose message is {@code prefix} followed by the driver's message.
     * NOTE(review): this echoes raw database errors to the client; confirm that is acceptable.
     */
    private static ResponseEntity<?> serverError(String prefix, SQLException e) {
        return ResponseEntity.status(500).body(Map.of("message", prefix + e.getMessage()));
    }

    /**
     * Accepts only names made of ASCII letters, digits and underscores.
     *
     * <p>This checks the shape only, not that the table exists. A missing table surfaces as a
     * database error from the query itself.
     */
    private static boolean isAllowedTable(String tableName) {
        return tableName != null && TABLE_NAME_PATTERN.matcher(tableName).matches();
    }

    /** Backtick-quotes an identifier, doubling embedded backticks, so it cannot break out of SQL. */
    private static String sanitizeIdentifier(String identifier) {
        return "`" + identifier.replace("`", "``") + "`";
    }
}