zl程序教程

您现在的位置是:首页 >  其他

当前栏目

分布式环境中spring cloud oauth2授权服务异常处理

2023-04-18 13:15:02 时间

环境

Spring Boot 2.3.7、Spring Cloud 2.2.6、Spring Security OAuth2 2.3.8;分布式部署多个 Spring Security OAuth2 授权服务器实例,并使用 Redis Session 同步会话。

问题

客户端通过授权码模式(authorization_code)获取令牌时,授权服务器间歇性地报错(invalid_grant: Invalid authorization code)。

分析

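Spring Security OAuth2 授权服务器默认使用 InMemoryAuthorizationCodeServices 在本机内存中管理授权码,因此分布式部署的多个授权服务器实例之间无法共享授权码。当负载均衡把获取令牌的请求转发到并非签发该授权码(即用户并未在其上登录认证)的实例时,该实例在本地找不到授权码,consumeAuthorizationCode 返回 null,进而抛出 InvalidGrantException。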

解决

自定义 AuthorizationCodeServices 的实现类 RedisAuthorizationCodeServicesImpl,使用 Redis 集中存储授权码,使分布式环境下的所有授权服务器实例共享同一份授权码。

import org.apache.commons.lang3.SerializationUtils;
import org.psrframework.core.util.UUIDUtil;
import org.springframework.dao.DataAccessException;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisStringCommands;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.types.Expiration;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.code.AuthorizationCodeServices;

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.Objects;

public class RedisAuthorizationCodeServicesImpl implements AuthorizationCodeServices {
    private final String REDIS_KEY_AUTH_CODE_PREFIX = "auth_code:";
    private final RedisTemplate stringRedisTemplate;

    public RedisAuthorizationCodeServicesImpl(RedisTemplate redisTemplate) {
        this.stringRedisTemplate = redisTemplate;
    }

    @Override
    public String createAuthorizationCode(OAuth2Authentication authentication) {
        for (int i = 0; i < 10; i++) {
            String code = UUIDUtil.generateTimebasedUUID().toString();
            byte[] key = (REDIS_KEY_AUTH_CODE_PREFIX + code).getBytes(StandardCharsets.UTF_8);
            byte[] data = SerializationUtils.serialize(authentication);
            Boolean result = (Boolean) stringRedisTemplate.execute(
                    new RedisCallback<Boolean>() {
                        @Override
                        public Boolean doInRedis(RedisConnection connection) throws DataAccessException {
                            if (connection.setNX(key, data)) {
                                connection.expire(key, 1500);
                                return true;
                            }
                            return false;
                        }
                    }
            );
            if (Boolean.TRUE.equals(result)) {
                return code;
            }
        }
        return null;
    }

    @Override
    public OAuth2Authentication consumeAuthorizationCode(String code) throws InvalidGrantException {
        byte[] key = (REDIS_KEY_AUTH_CODE_PREFIX + code).getBytes(StandardCharsets.UTF_8);
        byte[] data = (byte[]) stringRedisTemplate.execute(new RedisCallback() {
            @Override
            public byte[] doInRedis(RedisConnection connection) throws DataAccessException {
                byte[] data = connection.get(key);
                connection.del(key);
                return data;
            }
        });
        return SerializationUtils.deserialize(data);
    }
}

在授权服务器配置中注册 RedisAuthorizationCodeServicesImpl

/**
 * Excerpt of the authorization-server configuration. It shows only how the Redis-backed
 * authorization code service is registered; the parts marked "..." are omitted from this article.
 */
public class AuthorizationServerConfig extends AuthorizationServerConfigurerAdapter {
	...
	    @Override
    public void configure(AuthorizationServerEndpointsConfigurer endpoints) throws Exception {
        ...
        // Replace the default InMemoryAuthorizationCodeServices (chosen when none is configured)
        // so that every authorization-server instance reads and writes codes in the same Redis.
        // NOTE(review): redisTemplate is assumed to be an injected RedisTemplate field that is not shown here.
        endpoints.authorizationCodeServices(new RedisAuthorizationCodeServicesImpl(redisTemplate));
    }
}

Spring 授权码流程相关源码

  • 客户端 org.springframework.security.oauth2.client.OAuth2RestTemplate
/**
 * Obtains a new access token from the configured access-token provider and stores it in the
 * client context (excerpt from Spring Security OAuth2's OAuth2RestTemplate).
 *
 * @param oauth2Context client context holding the pending access-token request, any preserved
 *                      state, and any existing access token
 * @return the access token returned by the provider; never null and never with a null value
 * @throws AccessTokenRequiredException if the context has no access-token request
 * @throws IllegalStateException if the provider returns a null token or a token with a null value
 */
protected OAuth2AccessToken acquireAccessToken(OAuth2ClientContext oauth2Context)
		throws UserRedirectRequiredException {

	AccessTokenRequest accessTokenRequest = oauth2Context.getAccessTokenRequest();
	if (accessTokenRequest == null) {
		throw new AccessTokenRequiredException(
				"No OAuth 2 security context has been established. Unable to access resource '"
						+ this.resource.getId() + "'.", resource);
	}

	// Transfer the preserved state from the (longer lived) context to the current request.
	String stateKey = accessTokenRequest.getStateKey();
	// If the request carries a state key (e.g. when returning from the authorization redirect),
	// restore the state preserved under that key; removePreservedState also removes it from the context.
	if (stateKey != null) {
		accessTokenRequest.setPreservedState(oauth2Context.removePreservedState(stateKey));
	}

	// Hand any token already held in the context to the request (the provider may use it).
	OAuth2AccessToken existingToken = oauth2Context.getAccessToken();
	if (existingToken != null) {
		accessTokenRequest.setExistingToken(existingToken);
	}

	OAuth2AccessToken accessToken = null;
	// For the authorization-code grant this delegates to AuthorizationCodeAccessTokenProvider.
	accessToken = accessTokenProvider.obtainAccessToken(resource, accessTokenRequest);
	if (accessToken == null || accessToken.getValue() == null) {
		throw new IllegalStateException(
				"Access token provider returned a null access token, which is illegal according to the contract.");
	}
	oauth2Context.setAccessToken(accessToken);
	return accessToken;
}

org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider

/**
 * Obtains an access token using the authorization-code grant (excerpt from Spring Security
 * OAuth2's AuthorizationCodeAccessTokenProvider).
 *
 * @param details resource details; cast to AuthorizationCodeResourceDetails
 * @param request the current access-token request (may carry the authorization code and state key)
 * @return the token obtained from the token endpoint via retrieveToken
 * @throws UserRedirectRequiredException when neither an authorization code nor a state key is
 *         present: the user must first be redirected to the authorization endpoint
 */
public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request)
		throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException,
		OAuth2AccessDeniedException {

	AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;

	// No authorization code in the request yet:
	//  - no state key either  -> throw a redirect to the authorization page;
	//  - state key present    -> call obtainAuthorizationCode (not shown here; presumably it
	//                            fetches the code and stores it on the request — confirm upstream).
	if (request.getAuthorizationCode() == null) {
		if (request.getStateKey() == null) {
			throw getRedirectForAuthorization(resource, request);
		}
		obtainAuthorizationCode(resource, request);
	}
	// Exchange the authorization code for an access token at the token endpoint.
	return retrieveToken(request, resource, getParametersForTokenRequest(resource, request),
			getHeadersForTokenRequest(request));

}

/**
 * Sends the token request to the resource's access-token URI and extracts the access token from
 * the response (excerpt from Spring Security OAuth2's token-provider support code).
 *
 * @param request  the access-token request; receives any Set-Cookie header from the response
 * @param resource the protected resource (client) details
 * @param form     form parameters of the token request (for this grant, including the authorization code)
 * @param headers  HTTP headers of the token request
 * @return the access token extracted from the response
 * @throws OAuth2AccessDeniedException wrapping an OAuth2Exception returned by the server, or a
 *         RestClientException raised while calling it
 */
protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource,
		MultiValueMap<String, String> form, HttpHeaders headers) throws OAuth2AccessDeniedException {

	try {
		// Prepare headers and form before going into rest template call in case the URI is affected by the result
		authenticationHandler.authenticateTokenRequest(resource, form, headers);
		// Opportunity to customize form and headers
		tokenRequestEnhancer.enhance(request, resource, form, headers);
		// A final alias (not a copy) of request, so the anonymous extractor below can capture it.
		final AccessTokenRequest copy = request;

		final ResponseExtractor<OAuth2AccessToken> delegate = getResponseExtractor();
		// Wraps the default extractor to record the server's Set-Cookie header on the request.
		ResponseExtractor<OAuth2AccessToken> extractor = new ResponseExtractor<OAuth2AccessToken>() {
			@Override
			public OAuth2AccessToken extractData(ClientHttpResponse response) throws IOException {
				if (response.getHeaders().containsKey("Set-Cookie")) {
					copy.setCookie(response.getHeaders().getFirst("Set-Cookie"));
				}
				return delegate.extractData(response);
			}
		};
		return getRestTemplate().execute(getAccessTokenUri(resource, form), getHttpMethod(),
				getRequestCallback(resource, form, headers), extractor , form.toSingleValueMap());

	}
	catch (OAuth2Exception oe) {
		// e.g. invalid_grant returned by an authorization server that does not know the code
		throw new OAuth2AccessDeniedException("Access token denied.", resource, oe);
	}
	catch (RestClientException rce) {
		throw new OAuth2AccessDeniedException("Error requesting access token.", resource, rce);
	}

}
  • 授权服务器 org.springframework.security.oauth2.provider.code.AuthorizationCodeTokenGranter
/**
 * Redeems the authorization code from the token request and builds the resulting
 * OAuth2Authentication (excerpt from Spring Security OAuth2's AuthorizationCodeTokenGranter).
 *
 * @param client       the client requesting the token
 * @param tokenRequest the token request; its "code" parameter holds the authorization code
 * @return the stored user authentication combined with the merged request parameters
 * @throws InvalidRequestException if no authorization code was supplied
 * @throws InvalidGrantException if the authorization code service returns null for the code.
 *         With the default in-memory service this happens when the request reaches an instance
 *         that did not issue the code.
 * @throws RedirectMismatchException if the redirect URI differs from the one used at authorization time
 * @throws InvalidClientException if the token request's client id differs from the stored one
 */
@Override
protected OAuth2Authentication getOAuth2Authentication(ClientDetails client, TokenRequest tokenRequest) {

	Map<String, String> parameters = tokenRequest.getRequestParameters();
	String authorizationCode = parameters.get("code");
	String redirectUri = parameters.get(OAuth2Utils.REDIRECT_URI);

	if (authorizationCode == null) {
		throw new InvalidRequestException("An authorization code must be supplied.");
	}
	// Look up, and consume, the authentication stored for this code in the authorization code service.
	OAuth2Authentication storedAuth = authorizationCodeServices.consumeAuthorizationCode(authorizationCode);
	if (storedAuth == null) {
		throw new InvalidGrantException("Invalid authorization code: " + authorizationCode);
	}

	OAuth2Request pendingOAuth2Request = storedAuth.getOAuth2Request();
	// https://jira.springsource.org/browse/SECOAUTH-333
	// This might be null, if the authorization was done without the redirect_uri parameter
	String redirectUriApprovalParameter = pendingOAuth2Request.getRequestParameters().get(
			OAuth2Utils.REDIRECT_URI);

	if ((redirectUri != null || redirectUriApprovalParameter != null)
			&& !pendingOAuth2Request.getRedirectUri().equals(redirectUri)) {
		throw new RedirectMismatchException("Redirect URI mismatch.");
	}

	String pendingClientId = pendingOAuth2Request.getClientId();
	String clientId = tokenRequest.getClientId();
	if (clientId != null && !clientId.equals(pendingClientId)) {
		// just a sanity check.
		throw new InvalidClientException("Client ID mismatch");
	}

	// Secret is not required in the authorization request, so it won't be available
	// in the pendingAuthorizationRequest. We do want to check that a secret is provided
	// in the token request, but that happens elsewhere.

	Map<String, String> combinedParameters = new HashMap<String, String>(pendingOAuth2Request
			.getRequestParameters());
	// Combine the parameters adding the new ones last so they override if there are any clashes
	combinedParameters.putAll(parameters);
	
	// Make a new stored request with the combined parameters
	OAuth2Request finalStoredOAuth2Request = pendingOAuth2Request.createOAuth2Request(combinedParameters);
	
	Authentication userAuth = storedAuth.getUserAuthentication();
	
	return new OAuth2Authentication(finalStoredOAuth2Request, userAuth);

}

org.springframework.security.oauth2.config.annotation.web.configurers.AuthorizationServerEndpointsConfigurer

// If no AuthorizationCodeServices was configured, lazily fall back to InMemoryAuthorizationCodeServices,
// which keeps codes only in this JVM's memory (excerpt from AuthorizationServerEndpointsConfigurer).
private AuthorizationCodeServices authorizationCodeServices() {
	if (authorizationCodeServices == null) {
		authorizationCodeServices = new InMemoryAuthorizationCodeServices();
	}
	return authorizationCodeServices;
}

org.springframework.security.oauth2.provider.code.InMemoryAuthorizationCodeServices

/**
 * Authorization code store backed by a ConcurrentHashMap in the local JVM (excerpt from Spring
 * Security OAuth2). Because the map is per instance, a code stored by one server instance cannot
 * be found by another. Presumably the base class RandomValueAuthorizationCodeServices generates
 * the code and calls store/remove; it is not shown here.
 */
public class InMemoryAuthorizationCodeServices extends RandomValueAuthorizationCodeServices {
	// Authorization codes are kept in this instance's local memory only.
	protected final ConcurrentHashMap<String, OAuth2Authentication> authorizationCodeStore = new ConcurrentHashMap<String, OAuth2Authentication>();

	/** Stores the authentication under the given code. */
	@Override
	protected void store(String code, OAuth2Authentication authentication) {
		this.authorizationCodeStore.put(code, authentication);
	}

	/** Removes and returns the authentication for the code, or null if this instance has no such code. */
	@Override
	public OAuth2Authentication remove(String code) {
		OAuth2Authentication auth = this.authorizationCodeStore.remove(code);
		return auth;
	}

}