refactor: 移除不必要的依赖和初始化逻辑,简化配置更新处理

This commit is contained in:
Akizon77
2025-12-01 21:51:14 +08:00
parent 5b5cf32a28
commit f8b4e90d47
8 changed files with 107 additions and 96 deletions
-4
View File
@@ -24,10 +24,8 @@ import (
"github.com/komari-monitor/komari/internal/database/records"
"github.com/komari-monitor/komari/internal/database/tasks"
"github.com/komari-monitor/komari/internal/eventType"
"github.com/komari-monitor/komari/internal/geoip"
logutil "github.com/komari-monitor/komari/internal/log"
"github.com/komari-monitor/komari/internal/messageSender"
"github.com/komari-monitor/komari/internal/oauth"
"github.com/komari-monitor/komari/internal/patch"
"github.com/komari-monitor/komari/internal/restore"
"github.com/komari-monitor/komari/pkg/cloudflared"
@@ -80,10 +78,8 @@ func RunServer() {
event.Trigger(eventType.ServerInitializeStart, event.M{"config": config, "engine": r})
go geoip.InitGeoIp()
go DoScheduledWork()
go messageSender.Initialize()
go oauth.Initialize()
server.StartNezhaGRPCServer(config.NezhaCompatListen)
+2 -2
View File
@@ -45,11 +45,11 @@ func Override(cst Config) error {
oldConf := *Conf
Conf = &cst
event.Trigger(eventType.ConfigUpdated, event.M{
err, _ = event.Trigger(eventType.ConfigUpdated, event.M{
"old": oldConf,
"new": cst,
})
return nil
return err
}
func SavePartial(cst map[string]interface{}) error {
+1 -6
View File
@@ -59,10 +59,5 @@ func SaveLoadNotification(record models.LoadNotification) error {
}
func ReloadLoadNotificationSchedule() error {
db := dbcore.GetDBInstance()
var loadNotifications []models.LoadNotification
if err := db.Find(&loadNotifications).Error; err != nil {
return err
}
return notifier.ReloadLoadNotificationSchedule(loadNotifications)
return notifier.ReloadLoadNotification()
}
+52 -35
View File
@@ -7,7 +7,9 @@ import (
"time"
"unicode"
"github.com/gookit/event"
"github.com/komari-monitor/komari/internal/conf"
"github.com/komari-monitor/komari/internal/eventType"
"github.com/patrickmn/go-cache"
)
@@ -22,44 +24,31 @@ type GeoInfo struct {
func init() {
CurrentProvider = &EmptyProvider{}
geoCache = cache.New(48*time.Hour, 1*time.Hour)
}
// GeoIPService 接口定义了获取地理位置信息的核心方法。
// 任何实现此接口的类型都可以作为地理位置服务提供者。
type GeoIPService interface {
Name() string
GetGeoInfo(ip net.IP) (*GeoInfo, error)
UpdateDatabase() error
Close() error
}
func GetRegionUnicodeEmoji(isoCode string) string {
if len(isoCode) != 2 {
return ""
}
isoCode = strings.ToUpper(isoCode)
if !unicode.IsLetter(rune(isoCode[0])) || !unicode.IsLetter(rune(isoCode[1])) {
return ""
}
rune1 := rune(0x1F1E6 + (rune(isoCode[0]) - 'A'))
rune2 := rune(0x1F1E6 + (rune(isoCode[1]) - 'A'))
return string(rune1) + string(rune2)
}
func InitGeoIp() {
config, err := conf.GetWithV1Format()
err := SetProvider(conf.Conf.GeoIp.GeoIpProvider)
if err != nil {
panic("Failed to get configuration for GeoIP: " + err.Error())
log.Printf("Failed to set initial GeoIP provider: %v", err)
}
if !config.GeoIpEnabled {
return
}
switch config.GeoIpProvider {
event.On(eventType.ConfigUpdated, event.ListenerFunc(func(e event.Event) error {
oldConf, newConf, err := conf.FromEvent(e)
if oldConf.GeoIp.GeoIpProvider == newConf.GeoIp.GeoIpProvider {
return nil
}
if err != nil {
log.Printf("Failed to parse config from event: %v", err)
return err
}
err = SetProvider(newConf.GeoIp.GeoIpProvider)
if err != nil {
log.Printf("Failed to set GeoIP provider: %v", err)
}
return nil
}))
}
func SetProvider(provider string) error {
switch provider {
case "mmdb":
NewCurrentProvider, err := NewMaxMindGeoIPService()
if err != nil {
@@ -110,6 +99,34 @@ func InitGeoIp() {
default:
CurrentProvider = &EmptyProvider{}
}
return nil
}
// GeoIPService 接口定义了获取地理位置信息的核心方法。
// 任何实现此接口的类型都可以作为地理位置服务提供者。
type GeoIPService interface {
Name() string
GetGeoInfo(ip net.IP) (*GeoInfo, error)
UpdateDatabase() error
Close() error
}
func GetRegionUnicodeEmoji(isoCode string) string {
if len(isoCode) != 2 {
return ""
}
isoCode = strings.ToUpper(isoCode)
if !unicode.IsLetter(rune(isoCode[0])) || !unicode.IsLetter(rune(isoCode[1])) {
return ""
}
rune1 := rune(0x1F1E6 + (rune(isoCode[0]) - 'A'))
rune2 := rune(0x1F1E6 + (rune(isoCode[1]) - 'A'))
return string(rune1) + string(rune2)
}
func GetGeoInfo(ip net.IP) (*GeoInfo, error) {
-16
View File
@@ -1,15 +1,11 @@
package messageSender
import (
"fmt"
"log"
"github.com/gookit/event"
"github.com/komari-monitor/komari/internal/conf"
"github.com/komari-monitor/komari/internal/database"
"github.com/komari-monitor/komari/internal/database/auditlog"
"github.com/komari-monitor/komari/internal/eventType"
"github.com/komari-monitor/komari/internal/oauth"
)
func init() {
@@ -19,18 +15,6 @@ func init() {
log.Printf("Failed to parse config from event: %v", err)
return err
}
if newConf.Login.OAuthProvider != oldConf.Login.OAuthProvider {
oidcProvider, err := database.GetOidcConfigByName(newConf.Login.OAuthProvider)
if err != nil {
log.Printf("Failed to get OIDC provider config: %v", err)
} else {
log.Printf("Using %s as OIDC provider", oidcProvider.Name)
}
err = oauth.LoadProvider(oidcProvider.Name, oidcProvider.Addition)
if err != nil {
auditlog.EventLog("error", fmt.Sprintf("Failed to load OIDC provider: %v", err))
}
}
if newConf.Notification.NotificationMethod != oldConf.Notification.NotificationMethod {
Initialize()
}
+6 -2
View File
@@ -218,7 +218,11 @@ func updateLastNotified(taskId uint, notifyTime time.Time) {
}
}
// ReloadLoadNotificationSchedule 加载或重载时间表
func ReloadLoadNotificationSchedule(loadNotifications []models.LoadNotification) error {
func ReloadLoadNotification() error {
db := dbcore.GetDBInstance()
var loadNotifications []models.LoadNotification
if err := db.Find(&loadNotifications).Error; err != nil {
return err
}
return LoadNotificationManager.Reload(loadNotifications)
}
+45 -27
View File
@@ -6,16 +6,18 @@ import (
"log"
"sync"
"github.com/gookit/event"
"github.com/komari-monitor/komari/internal/conf"
"github.com/komari-monitor/komari/internal/database"
"github.com/komari-monitor/komari/internal/database/auditlog"
"github.com/komari-monitor/komari/internal/database/models"
"github.com/komari-monitor/komari/internal/eventType"
"github.com/komari-monitor/komari/internal/oauth/factory"
)
var (
currentProvider factory.IOidcProvider
mu = sync.Mutex{}
once = sync.Once{}
)
func CurrentProvider() factory.IOidcProvider {
@@ -47,44 +49,60 @@ func LoadProvider(name string, configJson string) error {
return nil
}
func Initialize() error {
once.Do(func() {
all := factory.GetAllOidcProviders()
for _, provider := range all {
if _, err := database.GetOidcConfigByName(provider.GetName()); err == nil {
continue
}
// 如果数据库中没有该提供者的配置,则保存默认配置
config := provider.GetConfiguration()
configBytes, err := json.Marshal(config)
if err != nil {
log.Printf("Failed to marshal config for provider %s: %v", provider.GetName(), err)
return
}
if err := database.SaveOidcConfig(&models.OidcProvider{
Name: provider.GetName(),
Addition: string(configBytes),
}); err != nil {
log.Printf("Failed to save default config for provider %s: %v", provider.GetName(), err)
return
}
func init() {
all := factory.GetAllOidcProviders()
for _, provider := range all {
if _, err := database.GetOidcConfigByName(provider.GetName()); err == nil {
continue
}
})
// 如果数据库中没有该提供者的配置,则保存默认配置
config := provider.GetConfiguration()
configBytes, err := json.Marshal(config)
if err != nil {
log.Printf("Failed to marshal config for provider %s: %v", provider.GetName(), err)
return
}
if err := database.SaveOidcConfig(&models.OidcProvider{
Name: provider.GetName(),
Addition: string(configBytes),
}); err != nil {
log.Printf("Failed to save default config for provider %s: %v", provider.GetName(), err)
return
}
}
cfg, _ := conf.GetWithV1Format()
if cfg.OAuthProvider == "" || cfg.OAuthProvider == "none" {
LoadProvider("empty", "{}")
return nil
}
provider, err := database.GetOidcConfigByName(cfg.OAuthProvider)
if err != nil {
// 如果没有找到配置,使用empty provider
LoadProvider("empty", "{}")
return nil
}
err = LoadProvider(provider.Name, provider.Addition)
if err != nil {
log.Printf("Failed to load OIDC provider %s: %v", provider.Name, err)
return err
}
return nil
event.On(eventType.ConfigUpdated, event.ListenerFunc(func(e event.Event) error {
oldConf, newConf, err := conf.FromEvent(e)
if err != nil {
log.Printf("Failed to parse config from event: %v", err)
}
if newConf.Login.OAuthProvider != oldConf.Login.OAuthProvider {
oidcProvider, err := database.GetOidcConfigByName(newConf.Login.OAuthProvider)
if err != nil {
log.Printf("Failed to get OIDC provider config: %v", err)
} else {
log.Printf("Using %s as OIDC provider", oidcProvider.Name)
}
err = LoadProvider(oidcProvider.Name, oidcProvider.Addition)
if err != nil {
auditlog.EventLog("error", fmt.Sprintf("Failed to load OIDC provider: %v", err))
}
}
return nil
}))
}
+1 -4
View File
@@ -7,7 +7,6 @@ import (
"github.com/komari-monitor/komari/internal/api_rpc"
"github.com/komari-monitor/komari/internal/conf"
"github.com/komari-monitor/komari/internal/eventType"
"github.com/komari-monitor/komari/internal/geoip"
"github.com/komari-monitor/komari/internal/messageSender"
"github.com/komari-monitor/komari/public"
)
@@ -25,9 +24,7 @@ func Init(r *gin.Engine) {
oldConf := e.Get("old").(conf.Config)
AllowCors = newConf.Site.AllowCors
public.UpdateIndex(newConf.ToV1Format())
if newConf.GeoIp.GeoIpProvider != oldConf.GeoIp.GeoIpProvider {
go geoip.InitGeoIp()
}
if newConf.Notification.NotificationMethod != oldConf.Notification.NotificationMethod {
go messageSender.Initialize()
}