Skip to content

Instantly share code, notes, and snippets.

@vuuvv
Last active April 16, 2020 01:38
Show Gist options
  • Save vuuvv/e8b6bcea0b527305fed831769d66c4fc to your computer and use it in GitHub Desktop.
Save vuuvv/e8b6bcea0b527305fed831769d66c4fc to your computer and use it in GitHub Desktop.
spring cloud gateway modify response
package com.vuuvv.vmall.gateway.filter;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import com.vuuvv.common.constant.ErrorCode;
import com.vuuvv.common.model.ApiResponse;
import com.vuuvv.common.utils.JwtUtils;
import com.vuuvv.jdbcplus.utils.JsonUtils;
import com.vuuvv.vmall.user.api.constant.Constants;
import com.vuuvv.vmall.user.api.constant.RouteAccess;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.Objects;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.addOriginalRequestUrl;
/**
 * Gateway filter that applies a route's access type and rewrites the request
 * path before the request is handed to the rest of the filter chain.
 *
 * <p>For every request the filter obtains the JWT secret (the value cached on
 * the factory, or a fresh one fetched from the user service), reads the user id
 * from the token in the {@code Authorization} header and then dispatches on the
 * configured type:
 * <ul>
 *   <li>{@code RouteAccess.LOGIN} / {@code RouteAccess.GUARD}: forward when a
 *       user id was resolved, otherwise try to refresh the token with the
 *       refresh-token cookie; a failed refresh answers with a
 *       {@code NOT_LOGIN} payload;</li>
 *   <li>any other type: forward unconditionally.</li>
 * </ul>
 *
 * <p>Created by vuuvv on 2020-04-15
 */
@Slf4j
public class AclGatewayFilter implements GatewayFilter, Ordered {

    private final AclGatewayFilterFactory.Config config;
    private final AclGatewayFilterFactory factory;
    private final WebClient.Builder clientBuilder;

    /**
     * @param config        route configuration (access type, path rewrite rule)
     * @param factory       owning factory; used as the shared cache for the JWT secret
     * @param clientBuilder builder for the client that calls the user service
     */
    AclGatewayFilter(AclGatewayFilterFactory.Config config, AclGatewayFilterFactory factory, WebClient.Builder clientBuilder) {
        this.config = config;
        this.factory = factory;
        this.clientBuilder = clientBuilder;
    }

    /**
     * Resolves the JWT secret and the caller's user id, then routes the request
     * according to the configured access type.
     *
     * @param exchange current exchange
     * @param chain    remaining gateway filters
     * @return completion signal of the handled request
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // justOrEmpty + defaultIfEmpty: a null cached secret must not blow up in Mono.just.
        return Mono.justOrEmpty(factory.getJwtSecret())
                .defaultIfEmpty("")
                .flatMap(secret -> {
                    if (StringUtils.isEmpty(secret)) {
                        return getSecret();
                    }
                    return Mono.just(secret);
                })
                .flatMap(secret -> {
                    val userId = getUserId(exchange, secret);
                    val type = config.getType();
                    if (Objects.equals(type, RouteAccess.LOGIN.toString())) {
                        return login(exchange, chain, config, userId);
                    }
                    if (Objects.equals(type, RouteAccess.GUARD.toString())) {
                        return guard(exchange, chain, config, userId);
                    }
                    return forward(exchange, chain, config, userId);
                });
    }

    @Override
    public int getOrder() {
        return -2;
    }

    /**
     * Handles a route that requires a logged-in user.
     */
    private Mono<Void> login(ServerWebExchange exchange, GatewayFilterChain chain, AclGatewayFilterFactory.Config config, String userId) {
        return forwardAuthenticated(exchange, chain, config, userId);
    }

    /**
     * Handles a guarded route. At the moment this performs exactly the same
     * check as {@link #login}: only the presence of a user id is verified.
     */
    private Mono<Void> guard(ServerWebExchange exchange, GatewayFilterChain chain, AclGatewayFilterFactory.Config config, String userId) {
        return forwardAuthenticated(exchange, chain, config, userId);
    }

    /**
     * Forwards when a user id is known, otherwise attempts a token refresh.
     */
    private Mono<Void> forwardAuthenticated(ServerWebExchange exchange, GatewayFilterChain chain, AclGatewayFilterFactory.Config config, String userId) {
        if (StringUtils.isEmpty(userId)) {
            return refresh(exchange, chain, config, userId);
        }
        return forward(exchange, chain, config, userId);
    }

    /**
     * Refreshes the JWT token by sending the caller's refresh-token cookie to
     * the user service. Without such a cookie, or when the user service does
     * not answer with a 2xx status, a NOT_LOGIN response is written.
     */
    private Mono<Void> refresh(ServerWebExchange exchange, GatewayFilterChain chain, AclGatewayFilterFactory.Config config, String userId) {
        val refreshTokenCookie = exchange.getRequest().getCookies().getFirst(Constants.Cookie.REFRESH_TOKEN);
        if (refreshTokenCookie == null || StringUtils.isEmpty(refreshTokenCookie.getValue())) {
            return loginErrorResponse(exchange);
        }
        return clientBuilder.build().post().uri("lb://vmall-user/users/refresh")
                .cookie(Constants.Cookie.REFRESH_TOKEN, refreshTokenCookie.getValue())
                .exchange()
                .flatMap(resp -> {
                    if (resp.statusCode().is2xxSuccessful()) {
                        return forwardWithRefreshToken(resp, exchange, chain, config, userId);
                    }
                    // exchange() requires the body to be consumed, otherwise the
                    // pooled connection is never released.
                    return resp.bodyToMono(Void.class)
                            .then(Mono.defer(() -> loginErrorResponse(exchange)));
                });
    }

    /**
     * Writes the serialized {@code ApiResponse.error(ErrorCode.NOT_LOGIN)} as
     * UTF-8 JSON to the client. The HTTP status is left untouched.
     */
    private Mono<Void> loginErrorResponse(ServerWebExchange exchange) {
        val response = exchange.getResponse();
        val payload = JsonUtils.stringify(ApiResponse.error(ErrorCode.NOT_LOGIN)).getBytes(StandardCharsets.UTF_8);
        // Tell the client what it is receiving; the bytes above are UTF-8 JSON.
        response.getHeaders().set(HttpHeaders.CONTENT_TYPE, "application/json;charset=UTF-8");
        return response.writeWith(Mono.just(response.bufferFactory().wrap(payload)));
    }

    /**
     * Builds the request that is forwarded downstream: the path is rewritten
     * according to the route configuration and the user id header is set.
     * The original URL is recorded on the exchange and the gateway request URL
     * attribute is updated to the rewritten URI.
     */
    private ServerHttpRequest generateForwardRequest(ServerWebExchange exchange, AclGatewayFilterFactory.Config config, String userId) {
        // "$\" in the configured replacement is an escaped "$" (group reference);
        // presumably written that way so placeholder resolution leaves it alone.
        String replacement = config.getReplacement().replace("$\\", "$");
        ServerHttpRequest req = exchange.getRequest();
        addOriginalRequestUrl(exchange, req.getURI());
        String path = req.getURI().getRawPath();
        // No regexp configured: the whole path becomes the replacement.
        String newPath = StrUtil.isEmpty(config.getRegexp()) ? replacement :
                path.replaceAll(config.getRegexp(), replacement);
        ServerHttpRequest request = req.mutate().header(Constants.Head.UserId, new String[] { userId }).path(newPath).build();
        exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, request.getURI());
        return request;
    }

    /**
     * Builds a response decorator that injects {@code jwtToken} into the
     * {@code auth} field of the downstream JSON body. Bodies that cannot be
     * parsed are passed through unchanged.
     *
     * <p>The complete body is collected before it is decoded, so JSON (or a
     * multi-byte character) split across several buffers is handled, the
     * response is written exactly once, and Content-Length reflects the number
     * of UTF-8 bytes actually written.
     *
     * <p>Not wired into the chain at the moment, see
     * {@link #forwardWithRefreshToken}.
     */
    private ServerHttpResponse generateForwardResponse(ServerWebExchange exchange, String jwtToken) {
        BodyHandlerFunction bodyHandler = (resp, body) -> Flux.from(body)
                .collect(ByteArrayOutputStream::new, (out, dataBuffer) -> {
                    // asInputStream(true) releases the buffer when the stream is closed.
                    try (InputStream in = dataBuffer.asInputStream(true)) {
                        IOUtils.copy(in, out);
                    } catch (IOException e) {
                        log.warn("forward", e);
                    }
                })
                .flatMap(out -> {
                    byte[] payload = out.toByteArray();
                    if (payload.length > 0) {
                        try {
                            val r = JsonUtils.parse(new String(payload, StandardCharsets.UTF_8),
                                    new TypeReference<ApiResponse<Object>>() {});
                            if (r != null) {
                                r.setAuth(jwtToken);
                                payload = JsonUtils.stringify(r).getBytes(StandardCharsets.UTF_8);
                            }
                        } catch (Exception e) {
                            // Not an ApiResponse: keep the downstream body as it is.
                            log.warn("forward", e);
                        }
                    }
                    resp.getHeaders().setContentLength(payload.length);
                    return resp.writeWith(Mono.just(resp.bufferFactory().wrap(payload)));
                });
        return new BodyHandlerServerHttpResponseDecorator(
                bodyHandler, exchange.getResponse());
    }

    /**
     * Forwards the request after a 2xx answer from the refresh endpoint.
     *
     * <p>The refreshed cookie is copied to the client response and the new token
     * is exposed in the {@code Authorization} response header. The user id is
     * re-read from the new token so the downstream service receives the
     * refreshed identity instead of the empty id that triggered the refresh.
     */
    private Mono<Void> forwardWithRefreshToken(ClientResponse resp, ServerWebExchange exchange, GatewayFilterChain chain, AclGatewayFilterFactory.Config config, String userId) {
        val cookie = resp.cookies().getFirst(Constants.Cookie.REFRESH_TOKEN);
        if (cookie == null) {
            // NOTE(review): the request is forwarded without validating the refresh
            // result in this branch - confirm that this is intended for protected routes.
            // The body is drained first so the client connection is released.
            return resp.bodyToMono(Void.class)
                    .then(Mono.defer(() -> forward(exchange, chain, config, userId)));
        }
        exchange.getResponse().addCookie(cookie);
        return resp.bodyToMono(String.class)
                // An empty body must still produce an answer instead of completing silently.
                .defaultIfEmpty("")
                .flatMap(body -> {
                    if (StringUtils.isEmpty(body)) {
                        return loginErrorResponse(exchange);
                    }
                    val r = JsonUtils.parse(body, new TypeReference<ApiResponse<Object>>() {});
                    if (r == null || !r.isSuccess()) {
                        return loginErrorResponse(exchange);
                    }
                    String token = r.getAuth();
                    exchange.getResponse().getHeaders().add(Constants.Head.Authorization, token);
                    String refreshedUserId = parseUserId(token, factory.getJwtSecret());
                    String forwardedUserId = StringUtils.isEmpty(refreshedUserId) ? userId : refreshedUserId;
                    // Body rewriting via generateForwardResponse(exchange, token) is
                    // intentionally not applied here; the token travels in the header.
                    return chain.filter(
                            exchange.mutate()
                                    .request(generateForwardRequest(exchange, config, forwardedUserId))
                                    .build()
                    );
                });
    }

    /**
     * Forwards the (rewritten) request to the target microservice.
     */
    private Mono<Void> forward(ServerWebExchange exchange, GatewayFilterChain chain, AclGatewayFilterFactory.Config config, String userId) {
        return chain.filter(exchange.mutate().request(
                generateForwardRequest(exchange, config, userId)
        ).build());
    }

    /**
     * Fetches the JWT secret from the user service and stores it on the factory.
     *
     * @return the secret, or an empty string when the service answers with a
     *         non-2xx status or an empty body
     */
    private Mono<String> getSecret() {
        return clientBuilder.build().get()
                .uri("lb://vmall-user/users/jwt")
                .exchange()
                .flatMap(resp -> {
                    if (resp.statusCode().is2xxSuccessful()) {
                        return resp.bodyToMono(String.class);
                    }
                    // Drain the error body so the connection is released.
                    return resp.bodyToMono(Void.class).then(Mono.just(""));
                })
                // Without a value the caller's chain would complete without handling the request.
                .defaultIfEmpty("")
                .doOnNext(factory::setJwtSecret);
    }

    /**
     * Reads the user id from the token in the request's Authorization header.
     *
     * @return the user id, or an empty string when there is no usable token
     */
    private String getUserId(ServerWebExchange exchange, String jwtSecret) {
        return parseUserId(exchange.getRequest().getHeaders().getFirst(Constants.Head.Authorization), jwtSecret);
    }

    /**
     * Extracts the user id claim from a JWT.
     *
     * @return the user id, or an empty string when the token is missing,
     *         cannot be parsed, is expired or carries no user id claim
     */
    private String parseUserId(String token, String jwtSecret) {
        if (StringUtils.isEmpty(token)) {
            return "";
        }
        try {
            val claim = JwtUtils.getClaimsFromToken(token, jwtSecret);
            // Expired token.
            if (claim.getExpiration().before(new Date())) {
                return "";
            }
            // String.valueOf(null) would yield the non-empty text "null" and be
            // mistaken for a real user id.
            Object id = claim.get(JwtUtils.USER_ID);
            return id == null ? "" : String.valueOf(id);
        } catch (Exception e) {
            log.warn("jwt token", e);
            return "";
        }
    }
}
package com.vuuvv.vmall.gateway.filter;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.factory.rewrite.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import javax.annotation.Resource;
import java.util.Arrays;
import java.util.List;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.ORIGINAL_RESPONSE_CONTENT_TYPE_ATTR;
/**
* Created by vuuvv on 2019-09-17
*/
@Component
@SuppressWarnings({"WeakerAccess", "unused"})
@Slf4j
public class AclGatewayFilterFactory
extends AbstractGatewayFilterFactory<AclGatewayFilterFactory.Config> {
/**
* Name id.
*/
public static final String KEY_ID = "id";
/**
* Name key.
*/
public static final String KEY_NAME = "name";
/**
* Type key.
*/
public static final String KEY_TYPE = "type";
/**
* Regexp key.
*/
public static final String KEY_REGEXP = "regexp";
/**
* Replacement key.
*/
public static final String KEY_REPLACEMENT = "replacement";
private String jwtSecret = "";
@Resource
private WebClient.Builder clientBuilder;
AclGatewayFilterFactory() {
super(Config.class);
}
@Override
public List<String> shortcutFieldOrder() {
return Arrays.asList(KEY_ID, KEY_TYPE, KEY_NAME, KEY_REGEXP, KEY_REPLACEMENT);
}
@Override
public GatewayFilter apply(Config config) {
return new AclGatewayFilter(config, this, clientBuilder);
}
public String getJwtSecret() {
return jwtSecret;
}
public void setJwtSecret(String jwtSecret) {
this.jwtSecret = jwtSecret;
}
@SuppressWarnings("unchecked")
ServerHttpResponse decorate(ServerWebExchange exchange) {
return new ServerHttpResponseDecorator(exchange.getResponse()) {
@Override
public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
String originalResponseContentType = exchange
.getAttribute(ORIGINAL_RESPONSE_CONTENT_TYPE_ATTR);
HttpHeaders httpHeaders = new HttpHeaders();
// explicitly add it in this way instead of
// 'httpHeaders.setContentType(originalResponseContentType)'
// this will prevent exception in case of using non-standard media
// types like "Content-Type: image"
httpHeaders.add(HttpHeaders.CONTENT_TYPE,
originalResponseContentType);
ClientResponse clientResponse = ClientResponse
.create(exchange.getResponse().getStatusCode())
.headers(headers -> headers.putAll(httpHeaders))
.body(Flux.from(body)).build();
// TODO: flux or mono
Mono modifiedBody = clientResponse.bodyToMono(String.class);
BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody,
String.class);
CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(
exchange, exchange.getResponse().getHeaders());
return bodyInserter.insert(outputMessage, new BodyInserterContext())
.then(Mono.defer(() -> {
Flux<DataBuffer> messageBody = outputMessage.getBody();
HttpHeaders headers = getDelegate().getHeaders();
if (!headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) {
messageBody = messageBody.doOnNext(data -> headers
.setContentLength(data.readableByteCount()));
}
// TODO: fail if isStreamingMediaType?
return getDelegate().writeWith(messageBody);
}));
}
@Override
public Mono<Void> writeAndFlushWith(
Publisher<? extends Publisher<? extends DataBuffer>> body) {
return writeWith(Flux.from(body).flatMapSequential(p -> p));
}
};
}
@Getter
@Setter
public static class Config {
/**
* route id
*/
private String id;
/**
* 守护的类型, anonymous: 匿名访问,login: 需登录访问, guard: 需权限访问
*/
private String type;
/**
* 路由的名称,也可以理解为权限的名称
*/
private String name;
/**
* 替换的正则表达式,如果为空,则直接替换成replacement
*/
private String regexp;
/**
* 替换后的路径
*/
private String replacement;
}
}
package com.vuuvv.vmall.gateway.filter;
import org.reactivestreams.Publisher;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.reactive.ServerHttpResponse;
import reactor.core.publisher.Mono;
import java.util.function.BiFunction;
/**
 * Strategy for writing a response body: receives the response to write to and
 * the body publisher, and returns the completion signal of the write.
 *
 * <p>Intended to be supplied as a lambda or method reference, hence the
 * {@code @FunctionalInterface} marker.
 *
 * <p>Created by vuuvv on 2020-04-14
 */
@FunctionalInterface
public interface BodyHandlerFunction extends
        BiFunction<ServerHttpResponse, Publisher<? extends DataBuffer>, Mono<Void>> {
}
package com.vuuvv.vmall.gateway.filter;
import org.reactivestreams.Publisher;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.NonNull;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
 * {@code ServerHttpResponse} wrapper that routes every body write through a
 * {@link BodyHandlerFunction}.
 */
public class BodyHandlerServerHttpResponseDecorator
        extends ServerHttpResponseDecorator {

    /**
     * Handler used when none is supplied: writes the body to the response as is.
     */
    private static final BodyHandlerFunction PASS_THROUGH = ReactiveHttpOutputMessage::writeWith;

    /**
     * Handler that intercepts the response body.
     */
    private final BodyHandlerFunction bodyHandler;

    /**
     * @param bodyHandler handler for the response body; {@code null} selects
     *                    the pass-through handler
     * @param delegate    the response being decorated
     */
    public BodyHandlerServerHttpResponseDecorator(BodyHandlerFunction bodyHandler, ServerHttpResponse delegate) {
        super(delegate);
        this.bodyHandler = (bodyHandler == null) ? PASS_THROUGH : bodyHandler;
    }

    /**
     * Hands the body to the handler together with the decorated response.
     */
    @Override
    public @NonNull Mono<Void> writeWith(@NonNull Publisher<? extends DataBuffer> body) {
        final ServerHttpResponse target = getDelegate();
        return bodyHandler.apply(target, body);
    }

    /**
     * Flattens the nested publishers in order and delegates to
     * {@link #writeWith}.
     */
    @Override
    public @NonNull Mono<Void> writeAndFlushWith(@NonNull Publisher<? extends Publisher<? extends DataBuffer>> body) {
        final Flux<DataBuffer> flattened = Flux.from(body).flatMapSequential(inner -> inner);
        return writeWith(flattened);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment