package com.xuqm.sdk.im

import com.google.gson.Gson
import com.xuqm.sdk.im.listener.ImEventListener
import com.xuqm.sdk.im.model.ImMessage
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import java.net.URI
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.TimeUnit

/**
 * Minimal STOMP 1.2 client over an OkHttp WebSocket for the IM service.
 *
 * After [connect], the client sends a STOMP `CONNECT` frame authenticated with [token]. Once the
 * server answers `CONNECTED`, it subscribes to `/user/queue/messages` and (re)sends a `SUBSCRIBE`
 * for every destination registered through [subscribe]. Incoming `MESSAGE` frames are decoded as
 * [ImMessage] and dispatched to the registered [ImEventListener]s.
 *
 * OkHttp delivers socket callbacks on its own background threads. State shared with the public
 * API is therefore either `@Volatile` or guarded by `synchronized(subscriptions)`.
 *
 * @param wsUrl WebSocket endpoint URL; its host is also sent as the STOMP `host` header.
 * @param token bearer token sent in the `Authorization` header of the `CONNECT` frame.
 * @param appId application id included in every outgoing chat payload.
 */
class ImClient(
    private val wsUrl: String,
    private val token: String,
    private val appId: String,
) {
    /** The socket created by the most recent [connect]; null before connecting and after [disconnect]. */
    @Volatile
    private var webSocket: WebSocket? = null

    /** Callbacks of the most recent socket; marked superseded when [connect] replaces it. */
    private var activeListener: SocketListener? = null

    private val listeners = CopyOnWriteArrayList<ImEventListener>()
    private val gson = Gson()

    /** Destination -> STOMP subscription id. Guarded by `synchronized(subscriptions)`. */
    private val subscriptions = mutableMapOf<String, String>()

    /** Counter behind [nextSubscriptionId]. Guarded by `synchronized(subscriptions)`. */
    private var subscriptionSeed = 0

    /** True between the server's `CONNECTED` frame and the socket closing, failing or being reset. */
    @Volatile
    private var connected = false

    /** Accumulates inbound text until a NUL frame terminator arrives. */
    private var inboundBuffer = StringBuilder()

    private val okhttp = OkHttpClient.Builder()
        .connectTimeout(10, TimeUnit.SECONDS)
        // 0 disables the read timeout: the socket may sit idle because STOMP heart-beats are off ("0,0").
        .readTimeout(0, TimeUnit.SECONDS)
        .build()

    /**
     * Opens a new WebSocket and starts the STOMP handshake.
     *
     * A previously opened socket is closed, and its callbacks are ignored from now on. Calling this
     * again (e.g. to reconnect) therefore never leaves a second live connection delivering messages.
     * Destinations registered with [subscribe] are kept and subscribed once the server answers
     * `CONNECTED`.
     */
    fun connect() {
        // Silence the old socket first so its onClosed/onFailure cannot report a disconnect for,
        // or reset the state of, the connection opened below.
        activeListener?.superseded = true
        resetState(closeReason = "Reconnect", clearSubscriptions = false)

        val request = Request.Builder()
            .url(wsUrl)
            .build()
        val listener = SocketListener()
        activeListener = listener
        webSocket = okhttp.newWebSocket(request, listener)
    }

    /**
     * Registers [destination] for subscription. This is a no-op if it is already registered.
     *
     * The `SUBSCRIBE` frame is sent immediately when connected. Otherwise it is sent after the
     * next `CONNECTED` frame.
     */
    fun subscribe(destination: String) {
        val subscriptionId = synchronized(subscriptions) {
            if (subscriptions.containsKey(destination)) return
            nextSubscriptionId().also { subscriptions[destination] = it }
        }
        if (connected) {
            sendSubscribe(destination, subscriptionId)
        }
    }

    /** Forgets [destination] and, when connected, sends the matching `UNSUBSCRIBE` frame. */
    fun unsubscribe(destination: String) {
        val subscriptionId = synchronized(subscriptions) { subscriptions.remove(destination) } ?: return
        if (connected) {
            sendFrame("UNSUBSCRIBE", mapOf("id" to subscriptionId), null)
        }
    }

    /**
     * Sends a chat message to `/app/chat.send` as a JSON body.
     *
     * @param mentionedUserIds included in the payload only when not null or blank.
     */
    fun sendMessage(
        toId: String,
        chatType: String,
        msgType: String,
        content: String,
        mentionedUserIds: String? = null,
    ) {
        val payload = linkedMapOf(
            "appId" to appId,
            "toId" to toId,
            "chatType" to chatType,
            "msgType" to msgType,
            "content" to content,
        )
        if (!mentionedUserIds.isNullOrBlank()) {
            payload["mentionedUserIds"] = mentionedUserIds
        }
        sendFrame(
            "SEND",
            mapOf(
                "destination" to "/app/chat.send",
                "content-type" to "application/json",
            ),
            gson.toJson(payload),
        )
    }

    /** Asks the server to revoke [messageId] by sending to `/app/chat.revoke`. */
    fun revokeMessage(messageId: String) {
        sendFrame(
            "SEND",
            mapOf(
                "destination" to "/app/chat.revoke",
                "content-type" to "application/json",
            ),
            gson.toJson(
                mapOf(
                    "appId" to appId,
                    "messageId" to messageId,
                )
            ),
        )
    }

    fun addListener(listener: ImEventListener) = listeners.add(listener)

    fun removeListener(listener: ImEventListener) = listeners.remove(listener)

    /** Closes the current socket and forgets all registered subscriptions. */
    fun disconnect() {
        resetState(closeReason = "User disconnect", clearSubscriptions = true)
    }

    /**
     * OkHttp callbacks for a single socket.
     *
     * Once [superseded] is set by [connect], the callbacks stop touching client state. This keeps
     * a replaced connection from affecting the current one.
     */
    private inner class SocketListener : WebSocketListener() {
        @Volatile
        var superseded = false

        override fun onOpen(webSocket: WebSocket, response: Response) {
            if (superseded) return
            // Use the callback's socket: this may run before connect() has stored it in the field.
            sendConnectFrame(webSocket)
        }

        override fun onMessage(webSocket: WebSocket, text: String) {
            if (superseded) return
            handleIncoming(text)
        }

        override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
            // Answer the peer's close frame. OkHttp reports onClosed only after both sides have sent one.
            // 1000 is used instead of echoing `code`, which may be a reserved value close() rejects.
            webSocket.close(NORMAL_CLOSURE, null)
        }

        override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
            if (superseded) return
            connected = false
            listeners.forEach { it.onDisconnected(t.message) }
        }

        override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
            if (superseded) return
            connected = false
            listeners.forEach { it.onDisconnected(reason) }
        }
    }

    /** Sends the STOMP `CONNECT` frame on [socket]; `connected` becomes true only on `CONNECTED`. */
    private fun sendConnectFrame(socket: WebSocket) {
        connected = false
        sendFrame(
            "CONNECT",
            mapOf(
                "accept-version" to "1.2",
                "heart-beat" to "0,0",
                "host" to URI.create(wsUrl).host.orEmpty(),
                "Authorization" to "Bearer $token",
            ),
            null,
            socket,
        )
    }

    /**
     * Appends a WebSocket text chunk to the inbound buffer and handles every complete frame.
     * STOMP frames end with a NUL character.
     */
    private fun handleIncoming(chunk: String) {
        if (chunk.isBlank()) return
        inboundBuffer.append(chunk)
        while (true) {
            val terminator = inboundBuffer.indexOf(FRAME_TERMINATOR)
            if (terminator < 0) return
            val frame = inboundBuffer.substring(0, terminator)
            // Remove the consumed frame in place rather than copying the remainder into a new builder.
            inboundBuffer.delete(0, terminator + 1)
            if (frame.isNotBlank()) {
                handleFrame(frame)
            }
        }
    }

    /** Parses one STOMP frame (without its NUL terminator) and reacts to its command. */
    private fun handleFrame(frame: String) {
        // STOMP allows EOLs between frames and CRLF line endings. Strip the former so the first
        // line is the command, and accept both LF and CRLF when splitting.
        val parts = frame.trimStart('\r', '\n').split(HEADER_BODY_SEPARATOR, limit = 2)
        val headerLines = parts.first().split(LINE_SEPARATOR).filter { it.isNotBlank() }
        val command = headerLines.firstOrNull()?.trim().orEmpty()
        val headers = parseHeaders(headerLines.drop(1))
        val body = parts.getOrNull(1).orEmpty()

        when (command.uppercase()) {
            "CONNECTED" -> onStompConnected()
            "MESSAGE" -> dispatchMessage(body)
            "ERROR" -> {
                val reason = body.ifBlank { headers["message"].orEmpty() }
                listeners.forEach { it.onError(reason.ifBlank { "STOMP error" }) }
            }
        }
    }

    /** Marks the session connected, notifies listeners and sends all subscriptions. */
    private fun onStompConnected() {
        connected = true
        listeners.forEach { it.onConnected() }
        sendSubscribe(USER_QUEUE, nextSubscriptionId(prefix = "user"))
        val pendingSubscriptions = synchronized(subscriptions) { subscriptions.toMap() }
        pendingSubscriptions.forEach { (destination, id) ->
            // The user queue was already subscribed just above.
            if (destination != USER_QUEUE) {
                sendSubscribe(destination, id)
            }
        }
    }

    /** Decodes a `MESSAGE` body and routes it to group or direct message callbacks. */
    private fun dispatchMessage(body: String) {
        runCatching {
            // Gson returns null for an empty body, and it can leave Kotlin non-null fields null.
            val msg = gson.fromJson(body, ImMessage::class.java)
                ?: throw IllegalArgumentException("empty message body")
            if ("GROUP".equals(msg.chatType, ignoreCase = true)) {
                listeners.forEach { it.onGroupMessage(msg) }
            } else {
                listeners.forEach { it.onMessage(msg) }
            }
        }.onFailure { e ->
            listeners.forEach { it.onError("Parse error: ${e.message}") }
        }
    }

    private fun sendSubscribe(destination: String, subscriptionId: String) {
        sendFrame(
            "SUBSCRIBE",
            mapOf(
                "id" to subscriptionId,
                "destination" to destination,
            ),
            null,
        )
    }

    /**
     * Serializes a STOMP frame and sends it on [socket], which defaults to the current socket.
     * Does nothing when there is no socket.
     */
    private fun sendFrame(
        command: String,
        headers: Map<String, String>,
        body: String?,
        socket: WebSocket? = webSocket,
    ) {
        val target = socket ?: return
        val frame = buildString {
            append(command).append('\n')
            headers.forEach { (key, value) ->
                append(escapeHeader(key)).append(':').append(escapeHeader(value)).append('\n')
            }
            append('\n')
            if (body != null) {
                append(body)
            }
            append(FRAME_TERMINATOR)
        }
        target.send(frame)
    }

    /** Parses `key:value` header lines; lines without a key are skipped. */
    private fun parseHeaders(lines: List<String>): Map<String, String> {
        val headers = linkedMapOf<String, String>()
        lines.forEach { line ->
            val index = line.indexOf(':')
            if (index <= 0) return@forEach
            val key = unescapeHeader(line.substring(0, index))
            // STOMP 1.2: if a header is repeated, only the first occurrence is used.
            if (key !in headers) {
                headers[key] = unescapeHeader(line.substring(index + 1))
            }
        }
        return headers
    }

    /** Returns a fresh `"$prefix-N"` id; the shared counter is guarded by the subscriptions lock. */
    private fun nextSubscriptionId(prefix: String = "sub"): String = synchronized(subscriptions) {
        subscriptionSeed += 1
        "$prefix-$subscriptionSeed"
    }

    /**
     * Drops the current connection state.
     *
     * @param closeReason when non-null, the current socket is closed with code 1000 and this reason.
     * @param clearSubscriptions whether registered destinations are forgotten too.
     */
    private fun resetState(closeReason: String?, clearSubscriptions: Boolean) {
        connected = false
        if (clearSubscriptions) {
            synchronized(subscriptions) { subscriptions.clear() }
        }
        inboundBuffer = StringBuilder()
        val socket = webSocket
        webSocket = null
        if (closeReason != null) {
            socket?.close(NORMAL_CLOSURE, closeReason)
        }
    }

    /** Escapes a header key or value per STOMP 1.2 (`\\`, `\r`, `\n`, `\c`). */
    private fun escapeHeader(value: String): String =
        value.replace("\\", "\\\\")
            .replace("\r", "\\r")
            .replace("\n", "\\n")
            .replace(":", "\\c")

    /** Reverses [escapeHeader]. An unknown escape yields the escaped character itself. */
    private fun unescapeHeader(value: String): String {
        val builder = StringBuilder()
        var index = 0
        while (index < value.length) {
            val ch = value[index]
            if (ch == '\\' && index + 1 < value.length) {
                when (val next = value[index + 1]) {
                    'r' -> builder.append('\r')
                    'n' -> builder.append('\n')
                    'c' -> builder.append(':')
                    '\\' -> builder.append('\\')
                    else -> builder.append(next)
                }
                index += 2
            } else {
                builder.append(ch)
                index += 1
            }
        }
        return builder.toString()
    }

    private companion object {
        /** WebSocket close status "normal closure". */
        const val NORMAL_CLOSURE = 1000

        /** STOMP frame terminator. */
        const val FRAME_TERMINATOR = "\u0000"

        /** Per-user queue subscribed automatically after `CONNECTED`. */
        const val USER_QUEUE = "/user/queue/messages"

        /** Blank line separating STOMP headers from the body (LF or CRLF). */
        val HEADER_BODY_SEPARATOR = Regex("\r?\n\r?\n")

        /** STOMP end-of-line (LF or CRLF). */
        val LINE_SEPARATOR = Regex("\r?\n")
    }
}