Files
komari/api/TokenAuthMiddleware.go
T
akizon77 7958b528ed feat: 更新代码库,替换包路径为 komari-monitor,重构部分功能
- 修改多个文件中的包导入路径,替换为 github.com/komari-monitor/komari
- 移除不再使用的 custom 包,改用 config 包
- 在 client.go 中重构 CreateClient 函数,简化参数
- 在 settings.go 中更新获取和保存配置的逻辑
- 在 client.go 中添加新的字段以支持更多客户端信息
- 在 sessions.go 中添加通过会话获取用户的功能
- 删除不必要的代码和文件,优化代码结构
2025-04-23 19:32:18 +08:00

82 lines
2.0 KiB
Go

package api
import (
"bytes"
"database/sql"
"encoding/json"
"io"
"net/http"
"github.com/komari-monitor/komari/database/clients"
"gorm.io/gorm"
"github.com/gin-gonic/gin"
)
// TokenAuthMiddleware creates a Gin middleware that validates a token from query parameters or request body.
func TokenAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var token string
// Step 1: Check query parameter for token
token = c.Query("token")
// Step 2: If no token in query, check request body for non-GET requests
if token == "" && c.Request.Method != http.MethodGet {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"status": "error", "error": "failed to read request body"})
return
}
// Restore the body for downstream handlers
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Try to extract token from JSON body
var bodyMap map[string]interface{}
if len(bodyBytes) > 0 {
if err := json.Unmarshal(bodyBytes, &bodyMap); err == nil {
if tokenVal, exists := bodyMap["token"]; exists {
if str, ok := tokenVal.(string); ok && str != "" {
token = str
}
}
}
}
}
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"status": "error", "error": "token is required"})
return
}
exists, err := checkTokenExists(token)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to validate token"})
return
}
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"status": "error", "error": "invalid token"})
return
}
c.Next()
}
}
func checkTokenExists(token string) (bool, error) {
_, err := clients.GetClientUUIDByToken(token)
if err == sql.ErrNoRows {
return false, nil
}
if err == gorm.ErrRecordNotFound {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}