package com.youlai.boot.message.registry;
import com.youlai.boot.message.dto.OnlineUserDTO;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.event.ContextClosedEvent;
import org.springframework.context.event.EventListener;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* SSE 会话注册表
*
* 维护 SSE 连接的用户会话信息,支持多设备同时登录
*/
@Slf4j
@Component
public class SseSessionRegistry {
/** 用户名 -> SseEmitter 集合(支持多设备) */
private final Map> userEmittersMap = new ConcurrentHashMap<>();
/** SseEmitter -> 用户名(快速定位用户) */
private final Map emitterUserMap = new ConcurrentHashMap<>();
/** SseEmitter -> 连接时间 */
private final Map emitterTimeMap = new ConcurrentHashMap<>();
/**
* 用户上线(建立 SSE 连接)
*
* @param username 用户名
* @param emitter SseEmitter
*/
public void userConnected(String username, SseEmitter emitter) {
userEmittersMap.computeIfAbsent(username, k -> ConcurrentHashMap.newKeySet()).add(emitter);
emitterUserMap.put(emitter, username);
emitterTimeMap.put(emitter, System.currentTimeMillis());
log.debug("用户[{}]SSE连接已建立", username);
// 设置连接超时和完成回调
emitter.onCompletion(() -> {
removeEmitter(emitter);
log.debug("用户[{}]SSE连接已完成", username);
});
emitter.onTimeout(() -> {
removeEmitter(emitter);
log.debug("用户[{}]SSE连接超时", username);
});
emitter.onError(e -> {
removeEmitter(emitter);
log.debug("用户[{}]SSE连接错误: {}", username, e.getMessage());
});
}
/**
* 移除指定 emitter
*/
private void removeEmitter(SseEmitter emitter) {
String username = emitterUserMap.remove(emitter);
if (username == null) {
return;
}
emitterTimeMap.remove(emitter);
Set emitters = userEmittersMap.get(username);
if (emitters != null) {
emitters.remove(emitter);
if (emitters.isEmpty()) {
userEmittersMap.remove(username);
log.debug("用户[{}]所有SSE连接已断开", username);
}
}
}
/**
* 用户下线(断开所有 SSE 连接)
*
* @param username 用户名
*/
public void userDisconnected(String username) {
Set emitters = userEmittersMap.remove(username);
if (emitters == null) {
return;
}
emitters.forEach(emitter -> {
emitterUserMap.remove(emitter);
emitterTimeMap.remove(emitter);
try {
emitter.complete();
} catch (Exception ignored) {
}
});
log.debug("用户[{}]已下线,移除{}个SSE连接", username, emitters.size());
}
/**
* 获取在线用户数量
*/
public int getOnlineUserCount() {
return userEmittersMap.size();
}
/**
* 获取总连接数量(包括多设备)
*/
public int getTotalConnectionCount() {
return emitterUserMap.size();
}
/**
* 获取指定用户的连接数量
*/
public int getUserConnectionCount(String username) {
Set emitters = userEmittersMap.get(username);
return emitters != null ? emitters.size() : 0;
}
/**
* 检查用户是否在线
*/
public boolean isUserOnline(String username) {
Set emitters = userEmittersMap.get(username);
return emitters != null && !emitters.isEmpty();
}
/**
* 获取所有在线用户列表
*/
public List getOnlineUsers() {
return userEmittersMap.entrySet().stream()
.map(entry -> {
String username = entry.getKey();
Set emitters = entry.getValue();
long earliestTime = emitters.stream()
.map(emitterTimeMap::get)
.filter(t -> t != null)
.mapToLong(Long::longValue)
.min()
.orElse(System.currentTimeMillis());
return new OnlineUserDTO(username, emitters.size(), earliestTime);
})
.collect(Collectors.toList());
}
/**
* 获取所有活跃的 SseEmitter
*/
public Set getAllEmitters() {
return emitterUserMap.keySet();
}
/**
* 获取指定用户的所有 SseEmitter
*/
public Set getUserEmitters(String username) {
return userEmittersMap.get(username);
}
/**
* 向指定 emitter 发送事件
*/
public boolean sendEvent(SseEmitter emitter, String eventName, Object data) {
try {
emitter.send(SseEmitter.event()
.name(eventName)
.data(data));
return true;
} catch (IOException e) {
log.warn("发送SSE事件失败: {}", e.getMessage());
removeEmitter(emitter);
return false;
}
}
/**
* 向所有连接广播事件
*/
public void broadcast(String eventName, Object data) {
getAllEmitters().forEach(emitter -> sendEvent(emitter, eventName, data));
}
/**
* 向指定用户发送事件
*/
public void sendToUser(String username, String eventName, Object data) {
Set emitters = userEmittersMap.get(username);
if (emitters != null) {
emitters.forEach(emitter -> sendEvent(emitter, eventName, data));
}
}
/**
* 心跳检测:每30秒向所有连接发送ping事件,及时清理已断开的僵尸连接
*/
@Scheduled(fixedRate = 30000)
public void heartbeat() {
if (emitterUserMap.isEmpty()) {
return;
}
List failedEmitters = new ArrayList<>();
for (SseEmitter emitter : emitterUserMap.keySet()) {
try {
emitter.send(SseEmitter.event().name("ping").data("heartbeat"));
} catch (Exception e) {
failedEmitters.add(emitter);
}
}
if (!failedEmitters.isEmpty()) {
log.debug("心跳检测清理{}个失效SSE连接", failedEmitters.size());
failedEmitters.forEach(this::removeEmitter);
}
}
/**
* 容器关闭时主动断开所有 SSE 连接,避免阻塞应用停止
*/
@Order(Ordered.HIGHEST_PRECEDENCE)
@EventListener(ContextClosedEvent.class)
public void destroy() {
int count = emitterUserMap.size();
if (count == 0) {
return;
}
log.info("应用关闭,主动断开 {} 个SSE连接...", count);
emitterUserMap.keySet().forEach(emitter -> {
try {
emitter.complete();
} catch (Exception ignored) {
}
});
userEmittersMap.clear();
emitterUserMap.clear();
emitterTimeMap.clear();
log.info("所有SSE连接已断开");
}
}