Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize: add jwt authentication for RegisterXXRequests #6317

Open
wants to merge 21 commits into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions core/src/main/java/org/apache/seata/core/auth/JwtAuthManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.seata.core.auth;



import org.apache.seata.common.util.StringUtils;
import org.apache.seata.config.ConfigurationFactory;

import java.util.HashMap;
import java.util.Map;

import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR;
import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR;


public class JwtAuthManager {
private String accessToken;

private String username;

private String password;

public final static String PRO_USERNAME = "username";
xingfudeshi marked this conversation as resolved.
Show resolved Hide resolved

public final static String PRO_PASSWORD = "password";

public final static String PRO_TOKEN = "token";

private static JwtAuthManager instance;

private JwtAuthManager() {
}

public static JwtAuthManager getInstance() {
if (instance == null) {
instance = new JwtAuthManager();
}
return instance;
}

public void init() {
username = ConfigurationFactory.CURRENT_FILE_INSTANCE.getConfig("security." + PRO_USERNAME);
xingfudeshi marked this conversation as resolved.
Show resolved Hide resolved
password = ConfigurationFactory.CURRENT_FILE_INSTANCE.getConfig("security." + PRO_PASSWORD);
}

public String getToken() {
return accessToken;
}

public String getUsername() {
return username;
}

public void setUsername(String username) {
this.username = username;
}

public String getPassword() {
return password;
}

public void setPassword(String password) {
this.password = password;
}

public void refreshToken(String newToken) {
accessToken = newToken;
}

public void setAccessToken(String token) {
accessToken = token;
}

public static HashMap<String, String> convertToHashMap(String inputString) {
HashMap<String, String> resultMap = new HashMap<>();
if (StringUtils.isBlank(inputString)) {
return resultMap;
}
String[] keyValuePairs = inputString.split(EXTRA_DATA_SPLIT_CHAR);
for (String pair : keyValuePairs) {
String[] keyValue = pair.trim().split(EXTRA_DATA_KV_CHAR);
if (keyValue.length == 2) {
resultMap.put(keyValue[0].trim(), keyValue[1].trim());
}
}
return resultMap;
}

public static String convertToString(HashMap<String, String> inputMap) {
if (inputMap == null || inputMap.isEmpty()) {
return "";
}
StringBuilder resultString = new StringBuilder();
for (Map.Entry<String, String> entry : inputMap.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String pair = key + EXTRA_DATA_KV_CHAR + value + EXTRA_DATA_SPLIT_CHAR;
resultString.append(pair);
}
if (resultString.length() > 0) {
resultString.deleteCharAt(resultString.length() - 1);
}
return resultString.toString();
}
public static String refreshAuthData(String extraData) {
HashMap<String,String> extraDataMap = convertToHashMap(extraData);
extraDataMap.remove(PRO_TOKEN);
if(null != getInstance().getToken()){
extraDataMap.put(PRO_TOKEN,getInstance().getToken());
}else if(null!= getInstance().getUsername() && null != getInstance().getPassword()){
extraDataMap.put(PRO_USERNAME,getInstance().getUsername());
extraDataMap.put(PRO_PASSWORD,getInstance().getPassword());
}
return convertToString(extraDataMap);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ public RegisterRMRequest(String applicationId, String transactionServiceGroup) {
super(applicationId, transactionServiceGroup);
}

/**
* Instantiates a new Register rm request.
*
* @param applicationId the application id
* @param transactionServiceGroup the transaction service group
* @param extraData the extra data
*/
public RegisterRMRequest(String applicationId, String transactionServiceGroup, String extraData) {
super(applicationId, transactionServiceGroup, extraData);
}

private String resourceIds;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ public enum ResultCode {
* Success result code.
*/
// Success
Success;
Success,

/**
* Retry result code.
*/
// Retry
Retry;

/**
* Get result code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.seata.core.rpc;

import org.apache.seata.common.exception.RetryableException;
import org.apache.seata.core.protocol.AbstractIdentifyRequest;
import org.apache.seata.core.protocol.RegisterRMRequest;
import org.apache.seata.core.protocol.RegisterTMRequest;

Expand All @@ -31,13 +33,19 @@ public interface RegisterCheckAuthHandler {
* @param request the request
* @return the boolean
*/
boolean regTransactionManagerCheckAuth(RegisterTMRequest request);
boolean regTransactionManagerCheckAuth(RegisterTMRequest request) throws RetryableException;

/**
* Reg resource manager check auth boolean.
*
* @param request the request
* @return the boolean
*/
boolean regResourceManagerCheckAuth(RegisterRMRequest request);
boolean regResourceManagerCheckAuth(RegisterRMRequest request) throws RetryableException;

/**
* Refresh token
* @return the String
*/
String refreshAuthToken(AbstractIdentifyRequest abstractIdentifyRequest) ;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
Expand All @@ -29,6 +30,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler.Sharable;
Expand All @@ -43,6 +45,7 @@
import org.apache.seata.common.util.CollectionUtils;
import org.apache.seata.common.util.NetUtil;
import org.apache.seata.common.util.StringUtils;
import org.apache.seata.core.auth.JwtAuthManager;
import org.apache.seata.core.protocol.AbstractMessage;
import org.apache.seata.core.protocol.HeartbeatMessage;
import org.apache.seata.core.protocol.MergeMessage;
Expand All @@ -63,11 +66,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_KV_CHAR;
import static org.apache.seata.common.ConfigurationKeys.EXTRA_DATA_SPLIT_CHAR;
import static org.apache.seata.common.exception.FrameworkErrorCode.NoAvailableService;

/**
* The netty remoting client.
*
*/
public abstract class AbstractNettyRemotingClient extends AbstractNettyRemoting implements RemotingClient {

Expand All @@ -91,7 +95,7 @@ public abstract class AbstractNettyRemotingClient extends AbstractNettyRemoting

/**
* When batch sending is enabled, the message will be stored to basketMap
* Send via asynchronous thread {@link AbstractNettyRemotingClient.MergedSendRunnable}
* Send via asynchronous thread {@link org.apache.seata.core.rpc.netty.AbstractNettyRemotingClient.MergedSendRunnable}
* {@link AbstractNettyRemotingClient#isEnableClientBatchSendRequest()}
*/
protected final ConcurrentHashMap<String/*serverAddress*/, BlockingQueue<RpcMessage>> basketMap = new ConcurrentHashMap<>();
Expand All @@ -100,10 +104,12 @@ public abstract class AbstractNettyRemotingClient extends AbstractNettyRemoting
private final NettyPoolKey.TransactionRole transactionRole;
private ExecutorService mergeSendExecutorService;
private TransactionMessageHandler transactionMessageHandler;
protected JwtAuthManager jwtAuthManager = JwtAuthManager.getInstance();
protected volatile boolean enableClientBatchSendRequest;

@Override
public void init() {
jwtAuthManager.init();
timerExecutor.scheduleAtFixedRate(() -> {
try {
clientChannelManager.reconnect(getTransactionServiceGroup());
Expand Down Expand Up @@ -172,7 +178,7 @@ public Object sendSyncRequest(Object msg) throws TimeoutException {
} catch (Exception exx) {
LOGGER.error("wait response error:{},ip:{},request:{}", exx.getMessage(), serverAddress, rpcMessage.getBody());
if (exx instanceof TimeoutException) {
throw (TimeoutException)exx;
throw (TimeoutException) exx;
} else {
throw new RuntimeException(exx);
}
Expand Down Expand Up @@ -295,6 +301,21 @@ protected String getXid(Object msg) {
return StringUtils.isBlank(xid) ? String.valueOf(ThreadLocalRandom.current().nextLong(Long.MAX_VALUE)) : xid;
}

protected String getAuthData() {
return JwtAuthManager.refreshAuthData(null);
}

protected void refreshAuthToken(String extraData) {
if (StringUtils.isBlank(extraData)) {
return;
}
HashMap<String, String> authData = JwtAuthManager.convertToHashMap(extraData);
String newToken = authData.get("newToken");
if (StringUtils.isNotBlank(newToken)) {
jwtAuthManager.refreshToken(newToken);
}
}

private String getThreadPrefix() {
return AbstractNettyRemotingClient.MERGE_THREAD_PREFIX + THREAD_PREFIX_SPLIT_CHAR + transactionRole.name();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@
import io.netty.channel.Channel;
import org.apache.seata.common.exception.FrameworkException;
import org.apache.seata.common.util.NetUtil;
import org.apache.seata.core.auth.JwtAuthManager;
import org.apache.seata.core.protocol.AbstractIdentifyRequest;
import org.apache.seata.core.protocol.RegisterRMRequest;
import org.apache.seata.core.protocol.RegisterTMRequest;
import org.apache.seata.core.protocol.RegisterRMResponse;
import org.apache.seata.core.protocol.RegisterTMResponse;
import org.apache.seata.core.protocol.ResultCode;
import org.apache.commons.pool.KeyedPoolableObjectFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* The type Netty key poolable factory.
*
*/
public class NettyPoolableFactory implements KeyedPoolableObjectFactory<NettyPoolKey, Channel> {

Expand Down Expand Up @@ -64,6 +68,19 @@ public Channel makeObject(NettyPoolKey key) {
}
try {
response = rpcRemotingClient.sendSyncRequest(tmpChannel, key.getMessage());
if (isRegisterExpired(response, key.getTransactionRole())) {
// relogin to get token
JwtAuthManager.getInstance().refreshToken(null);
AbstractIdentifyRequest request;
if (key.getTransactionRole().equals(NettyPoolKey.TransactionRole.TMROLE)) {
request = (RegisterTMRequest) key.getMessage();
} else {
request = (RegisterRMRequest) key.getMessage();
}
String identifyExtraData = JwtAuthManager.refreshAuthData(request.getExtraData());
request.setExtraData(identifyExtraData);
response = rpcRemotingClient.sendSyncRequest(tmpChannel, request);
}
if (!isRegisterSuccess(response, key.getTransactionRole())) {
rpcRemotingClient.onRegisterMsgFail(key.getAddress(), tmpChannel, response, key.getMessage());
} else {
Expand All @@ -85,6 +102,26 @@ public Channel makeObject(NettyPoolKey key) {
return channelToServer;
}

private boolean isRegisterExpired(Object response, NettyPoolKey.TransactionRole transactionRole) {
if (response == null) {
return false;
}
if (transactionRole.equals(NettyPoolKey.TransactionRole.TMROLE)) {
if (!(response instanceof RegisterTMResponse)) {
return false;
}
RegisterTMResponse registerTMResponse = (RegisterTMResponse) response;
return registerTMResponse.getResultCode().equals(ResultCode.Retry);
} else if (transactionRole.equals(NettyPoolKey.TransactionRole.RMROLE)) {
if (!(response instanceof RegisterRMResponse)) {
return false;
}
RegisterRMResponse registerRMResponse = (RegisterRMResponse) response;
return registerRMResponse.getResultCode().equals(ResultCode.Retry);
}
return false;
}

private boolean isRegisterSuccess(Object response, NettyPoolKey.TransactionRole transactionRole) {
if (response == null) {
return false;
Expand All @@ -93,13 +130,13 @@ private boolean isRegisterSuccess(Object response, NettyPoolKey.TransactionRole
if (!(response instanceof RegisterTMResponse)) {
return false;
}
RegisterTMResponse registerTMResponse = (RegisterTMResponse)response;
RegisterTMResponse registerTMResponse = (RegisterTMResponse) response;
return registerTMResponse.isIdentified();
} else if (transactionRole.equals(NettyPoolKey.TransactionRole.RMROLE)) {
if (!(response instanceof RegisterRMResponse)) {
return false;
}
RegisterRMResponse registerRMResponse = (RegisterRMResponse)response;
RegisterRMResponse registerRMResponse = (RegisterRMResponse) response;
return registerRMResponse.isIdentified();
}
return false;
Expand Down
Loading
Loading