package com.hz.employmentsite.util.datarange; import com.hz.employmentsite.enums.DataRangeEnum; import com.hz.employmentsite.services.service.AccountService; import com.hz.employmentsite.util.CustomerApplicationContext; import com.hz.employmentsite.util.datarange.annotations.*; import com.hz.employmentsite.vo.user.DataRange; import org.apache.ibatis.cache.CacheKey; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ResultMap; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @Intercepts({ @Signature( type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class} ), @Signature( type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class} )}) public class DataRangeInterceptor implements Interceptor { DataRangeUtils dataRangeUtils; AccountService accountService; public DataRangeInterceptor() { dataRangeUtils = CustomerApplicationContext.getBean(DataRangeUtils.class); accountService = CustomerApplicationContext.getBean(AccountService.class); } static int MAPPED_STATEMENT_INDEX = 0;// 这是对应上面的args的序号 static int PARAMETER_INDEX = 1; @Override public Object intercept(Invocation invocation) throws Throwable { Object[] args = invocation.getArgs(); MappedStatement ms = (MappedStatement) args[MAPPED_STATEMENT_INDEX]; Object parameter = args[PARAMETER_INDEX]; RowBounds rowBounds = (RowBounds) args[2]; ResultHandler resultHandler = (ResultHandler) args[3]; Executor executor = (Executor) invocation.getTarget(); CacheKey cacheKey; BoundSql boundSql; if (args.length == 4) { //4 个参数时 boundSql = ms.getBoundSql(parameter); cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql); } else { //6 个参数时 cacheKey = (CacheKey) args[4]; boundSql = (BoundSql) args[5]; } String mainSql = boundSql.getSql(); List resultMaps = ms.getResultMaps(); Class entityClass = resultMaps.get(0).getType(); if (Map.class.isAssignableFrom(entityClass)) { return invocation.proceed(); } else { String whereClause = ""; try { DataRange dataRange = this.dataRangeUtils.getCurrentRange(); if (dataRange == null || dataRange.getRange() == DataRangeEnum.ALL.getValue()) { return invocation.proceed(); } whereClause = getJoinSql(entityClass, dataRange); if (!whereClause.equals("")) { if(mainSql.contains("task.dotaskID,task.siteID")&&dataRange.getRange()==4){ whereClause += " or x.siteID is null "; } mainSql = "select x.* from" + "(" + mainSql + ") x " + whereClause; boundSql = new BoundSql(ms.getConfiguration(), mainSql, boundSql.getParameterMappings(), boundSql.getParameterObject()); } } catch (Exception ex) { return invocation.proceed(); } } try{ return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql); } catch (Exception ex){ return invocation.proceed(); } } private String getJoinSql(Class entityClass, DataRange dataRange) { return getJoinSqlByField(getField(entityClass, dataRange), dataRange); } private Field getField(Class entityClass, DataRange dataRange) { Field field = null; DataRangeEnum dataRangeEnum = DataRangeEnum.getDataRangeByValue(dataRange.getRange()); switch (dataRangeEnum) { case City: break; case Region: field = Arrays.stream(entityClass.getDeclaredFields()).filter(x -> x.getAnnotationsByType(RegionID.class).length > 0) .findFirst().orElse(null); break; case Institution: field = Arrays.stream(entityClass.getDeclaredFields()).filter(x -> x.getAnnotationsByType(InstitutionID.class).length > 0) .findFirst().orElse(null); break; case Site: field = Arrays.stream(entityClass.getDeclaredFields()).filter(x -> x.getAnnotationsByType(SiteID.class).length > 0) .findFirst().orElse(null); break; case Company: field = Arrays.stream(entityClass.getDeclaredFields()).filter(x -> x.getAnnotationsByType(CompanyID.class).length > 0) .findFirst().orElse(null); break; case SELF: field = Arrays.stream(entityClass.getDeclaredFields()).filter(x -> x.getAnnotationsByType(UserID.class).length > 0) .findFirst().orElse(null); break; default: break; } return field; } private String getInStatement(List rangeIDList) { List idList = rangeIDList; if (idList == null) { idList = new ArrayList<>(); } String inStatement = String.join(",", idList.stream().map(x -> "'" + x + "'").toArray(String[]::new)); return inStatement; } private String getJoinSqlByField(Field keyField, DataRange dataRange) { if (keyField == null || dataRange.getRangeIDList() == null || dataRange.getRangeIDList().size() == 0) return "where 1=1"; String keyFieldName = keyField.getName(); String inStatement = getInStatement(dataRange.getRangeIDList()); return " where x." + keyFieldName + " in (" + inStatement + ")"; } }