
使用ThreadLocal实现用户身份认证
什么是ThreadLocal
ThreadLocal
是Java语言提供的一种线程局部变量(Thread-Local Variables)机制,它可以为每个使用该变量的线程提供一个独立的变量副本。ThreadLocal
类在java.lang包中,它提供了线程局部变量的功能,这些局部变量与普通的变量不同,ThreadLocal
为每一个使用该变量的线程提供了一个独立初始化的变量副本。
ThreadLocal
的主要作用是提供线程内部的局部变量,这种变量在多线程环境下访问时能保证各个线程之间的数据隔离,即每个线程看到的是自己独立的ThreadLocal
变量副本,因此不会发生多线程间的数据共享问题。这在进行一些如用户身份认证、事务管理等需要隔离处理的操作时非常有用。
基本用法
- 初始化:可以通过继承
ThreadLocal
类并重写initialValue()
方法来初始化线程局部变量的初始值,或者通过调用ThreadLocal
的withInitial(Supplier<? extends T> supplier)
方法来提供一个供应商函数初始化值。 - 访问:通过
get()
方法来访问当前线程的局部变量副本的值,通过set(T value)
方法来设置当前线程的局部变量副本的值。 - 清除:通过
remove()
方法来清除当前线程的局部变量副本。
public class ThreadLocalExample {
// 创建一个ThreadLocal实例,用于存储每个线程的用户ID
private static final ThreadLocal<Integer> userId = ThreadLocal.withInitial(() -> null);
public static void main(String[] args) throws InterruptedException {
// 线程1:设置并获取用户ID
new Thread(() -> {
// 为当前线程设置用户ID
userId.set(1);
System.out.println("Thread 1: UserID = " + userId.get());
// 最后清除ThreadLocal中的数据
userId.remove();
}).start();
// 线程2:设置并获取用户ID
new Thread(() -> {
// 为当前线程设置用户ID
userId.set(2);
System.out.println("Thread 2: UserID = " + userId.get());
// 最后清除ThreadLocal中的数据
userId.remove();
}).start();
// 主线程等待上面的线程执行完毕
Thread.sleep(1000); // 确保上面的线程执行完毕,实际应用中应避免使用Thread.sleep()来等待线程结束
}
}
使用场景
- 用户身份认证:在处理用户请求时,可以把用户的身份信息保存在
ThreadLocal
中,这样在当前线程的任何地方都可以很方便地访问到用户信息,而不需要在方法间传递这些信息。 - 数据库事务管理:可以用
ThreadLocal
来保存每个线程的数据库连接对象,确保在同一个线程中使用的是同一个数据库连接,从而方便进行事务管理。 - 性能监控:在进行性能监控时,可以利用
ThreadLocal
存储一些计时数据,以便在同一线程的不同执行点之间共享这些数据。
注意事项
虽然ThreadLocal
很有用,但也需要注意其潜在的内存泄露问题。由于ThreadLocal
的生命周期与线程一样长,如果线程不死亡,则ThreadLocal
变量和其持有的对象就不会被垃圾回收,特别是在使用线程池的情况下。因此,使用ThreadLocal
时,一定要在不再需要使用变量时调用remove()
方法来清除线程局部变量,以帮助垃圾回收器回收这部分内存。
项目环境介绍
- SpringBoot
- Spring Web
- Mysql Driver
可以设计表 t_user
的字段如下:
可以使用 Mybatis Plus 快速生成 CRUD 和 pojo
准备工具类
ThreadLocalUtil
:
/**
* ThreadLocal 工具类
*/
@SuppressWarnings("all")
public class ThreadLocalUtil {
//提供ThreadLocal对象,
private static final ThreadLocal THREAD_LOCAL = new ThreadLocal();
//根据键获取值
public static <T> T get() {
return (T) THREAD_LOCAL.get();
}
//存储键值对
public static void set(Object value) {
THREAD_LOCAL.set(value);
}
//清除ThreadLocal 防止内存泄漏
public static void remove() {
THREAD_LOCAL.remove();
}
}
JwtUtil
:
- 记得 pom 添加
auth0
的jwt
public class JwtUtil {
private static final String KEY = "itheima";
//接收业务数据,生成token并返回
public static String genToken(Map<String, Object> claims) {
return JWT.create()
.withClaim("claims", claims)
.withExpiresAt(new Date(System.currentTimeMillis() + 1000 * 60 * 60 * 12))
.sign(Algorithm.HMAC256(KEY));
}
//接收token,验证token,并返回业务数据
public static Map<String, Object> parseToken(String token) {
return JWT.require(Algorithm.HMAC256(KEY))
.build()
.verify(token)
.getClaim("claims")
.asMap();
}
}
Md5Util
:
public class Md5Util {
/**
* 默认的密码字符串组合,用来将字节转换成 16 进制表示的字符,apache校验下载的文件的正确性用的就是默认的这个组合
*/
protected static char hexDigits[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
protected static MessageDigest messagedigest = null;
static {
try {
messagedigest = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException nsaex) {
System.err.println(Md5Util.class.getName() + "初始化失败,MessageDigest不支持MD5Util。");
nsaex.printStackTrace();
}
}
/**
* 生成字符串的md5校验值
*
* @param s
* @return
*/
public static String getMD5String(String s) {
return getMD5String(s.getBytes());
}
/**
* 判断字符串的md5校验码是否与一个已知的md5码相匹配
*
* @param password 要校验的字符串
* @param md5PwdStr 已知的md5校验码
* @return
*/
public static boolean checkPassword(String password, String md5PwdStr) {
String s = getMD5String(password);
return s.equals(md5PwdStr);
}
public static String getMD5String(byte[] bytes) {
messagedigest.update(bytes);
return bufferToHex(messagedigest.digest());
}
private static String bufferToHex(byte bytes[]) {
return bufferToHex(bytes, 0, bytes.length);
}
private static String bufferToHex(byte bytes[], int m, int n) {
StringBuffer stringbuffer = new StringBuffer(2 * n);
int k = m + n;
for (int l = m; l < k; l++) {
appendHexPair(bytes[l], stringbuffer);
}
return stringbuffer.toString();
}
private static void appendHexPair(byte bt, StringBuffer stringbuffer) {
char c0 = hexDigits[(bt & 0xf0) >> 4];// 取字节中高 4 位的数字转换, >>>
// 为逻辑右移,将符号位一起右移,此处未发现两种符号有何不同
char c1 = hexDigits[bt & 0xf];// 取字节中低 4 位的数字转换
stringbuffer.append(c0);
stringbuffer.append(c1);
}
}
添加拦截器拦截请求
LoginInterceptor
:
@Configuration
public class WebConfig implements WebMvcConfigurer {
private final LoginInterceptor loginInterceptor;
@Value("${server.baseUrl}")
private String url;
public WebConfig(LoginInterceptor loginInterceptor) {
this.loginInterceptor = loginInterceptor;
}
@Override
public void addCorsMappings(CorsRegistry registry) {
registry.addMapping("/**")
.allowedOrigins(url) // 替换为你的前端地址
.allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS")
.allowedHeaders("*")
.allowCredentials(true);
}
@Override
public void addInterceptors(InterceptorRegistry registry) {
// 登录接口和注册接口不拦截
registry.addInterceptor(loginInterceptor).excludePathPatterns("/user/login", "/user/register", "/label", "/category");
}
}
添加webConfig设置放行地址
WebConfig
:
@Configuration
public class WebConfig implements WebMvcConfigurer {
private final LoginInterceptor loginInterceptor;
@Value("${server.baseUrl}")
private String url;
public WebConfig(LoginInterceptor loginInterceptor) {
this.loginInterceptor = loginInterceptor;
}
@Override
public void addCorsMappings(CorsRegistry registry) {
registry.addMapping("/**")
.allowedOrigins(url) // 替换为你的前端地址
.allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS")
.allowedHeaders("*")
.allowCredentials(true);
}
@Override
public void addInterceptors(InterceptorRegistry registry) {
// 登录接口和注册接口不拦截
registry.addInterceptor(loginInterceptor).excludePathPatterns("/user/login", "/user/register", "/label", "/category");
}
}
设置异常处理器
GlobalExceptionHandler
:
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(Exception.class)
public Result<Object> handleException(Exception e){
e.printStackTrace();
return Result.error(StringUtils.hasLength(e.getMessage())? e.getMessage() : "操作失败");
}
}
使用步骤
创建对应的controller
UserController
:
@RestController
@RequestMapping("/user")
@Validated
public class UserController {
@Autowired
private TUserService userService;
@PostMapping("/register")
public Result register(@Pattern(regexp = "^\\S{5,16}$") String username, @Pattern(regexp = "^\\S{5,16}$") String password) {
// 查询用户
TUser u = userService.findByUserName(username);
if (u == null) {
// 没有占用
//注册
userService.register(username, password);
return Result.ok();
} else {
// 占用
return Result.error("用户名已被占用");
}
}
@PostMapping("/login")
public Result login(String username, String password) {
// 查询用户
TUser loginUser = userService.findByUserName(username);
if (loginUser == null) {
return Result.error("用户名错误");
}
if (Md5Util.getMD5String(password).equals(loginUser.getPassword())) {
// 登录成功
// 生成token
Map<String, Object> claims = new HashMap<>();
claims.put("id", loginUser.getId());
claims.put("username", loginUser.getUsername());
claims.put("role", loginUser.getRole());
String token = JwtUtil.genToken(claims);
return Result.ok(token);
}
return Result.error("密码错误");
}
@GetMapping("/userInfo")
public Result<TUser> userInfo() {
// 根据用户名查询用户
Map<String, Object> map = ThreadLocalUtil.get();
String username = (String) map.get("username");
TUser user = userService.findByUserName(username);
user.setRoleName(userService.findByRoleName(user.getRole()));
return Result.ok(user);
}
@PutMapping("/update")
public Result update(@RequestBody @Validated TUser user) {
userService.update(user);
return Result.ok();
}
@PatchMapping("/updateAvatar")
public Result updateAvatar(@RequestParam @URL String avatarUrl) {
userService.updateAvatar(avatarUrl);
return Result.ok();
}
@PatchMapping("/updatePwd")
public Result updatePwd(@RequestBody Map<String, String> params) {
// 1. 校验参数
String oldPwd = params.get("old_pwd");
String newPwd = params.get("new_pwd");
String rePwd = params.get("re_pwd");
if (!StringUtils.hasLength(oldPwd) || !StringUtils.hasLength(newPwd) || !StringUtils.hasLength(rePwd)) {
return Result.error("缺少必要的参数");
}
// 原密码是否正确
// 调用userService根据用户名拿到原密码,再和old_pwd对比
Map<String, Object> map = ThreadLocalUtil.get();
String username = (String) map.get("username");
TUser loginUser = userService.findByUserName(username);
if (!loginUser.getPassword().equals(Md5Util.getMD5String(oldPwd))) {
return Result.error("原密码填写不正确");
}
// newPwd 和 rePwd是否一样
if (!rePwd.equals(newPwd)) {
return Result.error("两次填写密码不一致");
}
// 2. 调用service 完成密码更新
userService.updatePwd(newPwd);
return Result.ok();
}
}
添加对应的service
serviceImpl
:
@Service
public class TUserServiceImpl extends ServiceImpl<TUserMapper, TUser>
implements TUserService {
final TUserMapper tUserMapper;
final TRoleService tRoleService;
public TUserServiceImpl(TRoleMapper tRoleMapper, TUserMapper tUserMapper, TRoleService tRoleService) {
this.tUserMapper = tUserMapper;
this.tRoleService = tRoleService;
}
/**
* 根据用户名查询用户
*
* @param username
* @return
*/
public TUser findByUserName(String username) {
return query().eq("username", username).one();
}
/**
* 注册
*
* @param username
* @param password
*/
public void register(String username, String password) {
TUser tUser = new TUser();
tUser.setUpdateTime(LocalDateTime.now());
tUser.setCreateTime(LocalDateTime.now());
tUser.setUsername(username);
tUser.setPassword(Md5Util.getMD5String(password));
tUser.setAvatar("https://picgo.cn-sy1.rains3.com/2024/08/a3afdbb7f0c3ada619fdfe7d16692fab.jpg");
save(tUser);
}
/**
* 更新用户
*
* @param user
*/
@Override
public void update(TUser user) {
user.setUpdateTime(LocalDateTime.now());
updateById(user);
}
/**
* 更新用户头像
*
* @param avatarUrl
*/
@Override
public void updateAvatar(String avatarUrl) {
Map<String, Object> map = ThreadLocalUtil.get();
Integer id = (Integer) map.get("id");
tUserMapper.updateAvatar(avatarUrl, id);
}
/**
* 更新密码
*
* @param newPwd
*/
@Override
public void updatePwd(String newPwd) {
Map<String, Object> map = ThreadLocalUtil.get();
Integer id = (Integer) map.get("id");
tUserMapper.updatePwd(Md5Util.getMD5String(newPwd), id);
}
@Override
public String findByRoleName(Integer id) {
String desc = tRoleService.query().eq("id", id).select("role_name").one().getRoleName();
return desc != null ? desc : "游客";
}
}
- 感谢你赐予我前进的力量
赞赏者名单
因为你们的支持让我意识到写文章的价值🙏
本文是原创文章,采用 CC BY-NC-ND 4.0 协议,完整转载请注明来自 zxb
评论
隐私政策
你无需删除空行,直接评论以获取最佳展示效果