package org.datasurvey.config; import java.security.Principal; import java.util.*; import org.datasurvey.security.AuthoritiesConstants; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.server.*; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.config.annotation.*; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import tech.jhipster.config.JHipsterProperties; @Configuration @EnableWebSocketMessageBroker public class WebsocketConfiguration implements WebSocketMessageBrokerConfigurer { public static final String IP_ADDRESS = "IP_ADDRESS"; private final JHipsterProperties jHipsterProperties; public WebsocketConfiguration(JHipsterProperties jHipsterProperties) { this.jHipsterProperties = jHipsterProperties; } @Override public void configureMessageBroker(MessageBrokerRegistry config) { config.enableSimpleBroker("/topic"); } @Override public void registerStompEndpoints(StompEndpointRegistry registry) { String[] allowedOrigins = Optional .ofNullable(jHipsterProperties.getCors().getAllowedOrigins()) .map(origins -> origins.toArray(new String[0])) .orElse(new String[0]); registry .addEndpoint("/websocket/tracker") .setHandshakeHandler(defaultHandshakeHandler()) .setAllowedOrigins(allowedOrigins) .withSockJS() .setInterceptors(httpSessionHandshakeInterceptor()); } @Bean public HandshakeInterceptor httpSessionHandshakeInterceptor() { return new HandshakeInterceptor() { @Override public boolean beforeHandshake( ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes ) throws Exception { if (request instanceof ServletServerHttpRequest) { ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; attributes.put(IP_ADDRESS, servletRequest.getRemoteAddress()); } return true; } @Override public void afterHandshake( ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception ) {} }; } private DefaultHandshakeHandler defaultHandshakeHandler() { return new DefaultHandshakeHandler() { @Override protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, Map attributes) { Principal principal = request.getPrincipal(); if (principal == null) { Collection authorities = new ArrayList<>(); authorities.add(new SimpleGrantedAuthority(AuthoritiesConstants.ANONYMOUS)); principal = new AnonymousAuthenticationToken("WebsocketConfiguration", "anonymous", authorities); } return principal; } }; } }