|
|
|
|
|
|
|
|
|
package auth |
|
|
|
import ( |
|
"encoding/json" |
|
"net/http" |
|
"strconv" |
|
"time" |
|
|
|
"github.com/GoAdminGroup/go-admin/context" |
|
"github.com/GoAdminGroup/go-admin/modules/config" |
|
"github.com/GoAdminGroup/go-admin/modules/db" |
|
"github.com/GoAdminGroup/go-admin/modules/db/dialect" |
|
"github.com/GoAdminGroup/go-admin/modules/logger" |
|
"github.com/GoAdminGroup/go-admin/plugins/admin/modules" |
|
) |
|
|
|
// DefaultCookieKey is the cookie name InitSession configures for the session
// id cookie.
const (
	DefaultCookieKey = "go_admin_session"
)
|
|
|
|
|
func newDBDriver(conn db.Connection) *DBDriver { |
|
return &DBDriver{ |
|
conn: conn, |
|
tableName: "goadmin_session", |
|
} |
|
} |
|
|
|
|
|
// PersistenceDriver is the storage backend a Session uses to read and write
// its values.
type PersistenceDriver interface {
	// Load returns the values stored for the session id sid.
	Load(sid string) (values map[string]interface{}, err error)

	// Update stores values as the data of the session id sid.
	Update(sid string, values map[string]interface{}) (err error)
}
|
|
|
|
|
func GetSessionByKey(sesKey, key string, conn db.Connection) (interface{}, error) { |
|
m, err := newDBDriver(conn).Load(sesKey) |
|
return m[key], err |
|
} |
|
|
|
|
|
type Session struct { |
|
Expires time.Duration |
|
Cookie string |
|
Values map[string]interface{} |
|
Driver PersistenceDriver |
|
Sid string |
|
Context *context.Context |
|
} |
|
|
|
|
|
// Config carries the settings that Session.UpdateConfig copies onto a Session.
type Config struct {
	// Expires is the lifetime assigned to Session.Expires.
	Expires time.Duration
	// Cookie is the cookie name assigned to Session.Cookie.
	Cookie string
}
|
|
|
|
|
func (ses *Session) UpdateConfig(config Config) { |
|
ses.Expires = config.Expires |
|
ses.Cookie = config.Cookie |
|
} |
|
|
|
|
|
func (ses *Session) Get(key string) interface{} { |
|
return ses.Values[key] |
|
} |
|
|
|
|
|
func (ses *Session) Add(key string, value interface{}) error { |
|
ses.Values[key] = value |
|
if err := ses.Driver.Update(ses.Sid, ses.Values); err != nil { |
|
return err |
|
} |
|
cookie := http.Cookie{ |
|
Name: ses.Cookie, |
|
Value: ses.Sid, |
|
MaxAge: config.GetSessionLifeTime(), |
|
Expires: time.Now().Add(ses.Expires), |
|
HttpOnly: true, |
|
Path: "/", |
|
} |
|
if config.GetDomain() != "" { |
|
cookie.Domain = config.GetDomain() |
|
} |
|
ses.Context.SetCookie(&cookie) |
|
return nil |
|
} |
|
|
|
|
|
func (ses *Session) Clear() error { |
|
ses.Values = map[string]interface{}{} |
|
return ses.Driver.Update(ses.Sid, ses.Values) |
|
} |
|
|
|
|
|
func (ses *Session) UseDriver(driver PersistenceDriver) { |
|
ses.Driver = driver |
|
} |
|
|
|
|
|
func (ses *Session) StartCtx(ctx *context.Context) (*Session, error) { |
|
if cookie, err := ctx.Request.Cookie(ses.Cookie); err == nil && cookie.Value != "" { |
|
ses.Sid = cookie.Value |
|
valueFromDriver, err := ses.Driver.Load(cookie.Value) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if len(valueFromDriver) > 0 { |
|
ses.Values = valueFromDriver |
|
} |
|
} else { |
|
ses.Sid = modules.Uuid() |
|
} |
|
ses.Context = ctx |
|
return ses, nil |
|
} |
|
|
|
|
|
func InitSession(ctx *context.Context, conn db.Connection) (*Session, error) { |
|
|
|
sessions := new(Session) |
|
sessions.UpdateConfig(Config{ |
|
Expires: time.Second * time.Duration(config.GetSessionLifeTime()), |
|
Cookie: DefaultCookieKey, |
|
}) |
|
|
|
sessions.UseDriver(newDBDriver(conn)) |
|
sessions.Values = make(map[string]interface{}) |
|
|
|
return sessions.StartCtx(ctx) |
|
} |
|
|
|
|
|
type DBDriver struct { |
|
conn db.Connection |
|
tableName string |
|
} |
|
|
|
|
|
func (driver *DBDriver) Load(sid string) (map[string]interface{}, error) { |
|
sesModel, err := driver.table().Where("sid", "=", sid).First() |
|
|
|
if db.CheckError(err, db.QUERY) { |
|
return nil, err |
|
} |
|
|
|
if sesModel == nil { |
|
return map[string]interface{}{}, nil |
|
} |
|
|
|
var values map[string]interface{} |
|
err = json.Unmarshal([]byte(sesModel["values"].(string)), &values) |
|
return values, err |
|
} |
|
|
|
func (driver *DBDriver) deleteOverdueSession() { |
|
|
|
defer func() { |
|
if err := recover(); err != nil { |
|
logger.Error(err) |
|
panic(err) |
|
} |
|
}() |
|
|
|
var ( |
|
duration = strconv.Itoa(config.GetSessionLifeTime() + 1000) |
|
driverName = config.GetDatabases().GetDefault().Driver |
|
raw = `` |
|
) |
|
|
|
if db.DriverPostgresql == driverName { |
|
raw = `extract(epoch from now()) - ` + duration + ` > extract(epoch from created_at)` |
|
} else if db.DriverMysql == driverName { |
|
raw = `unix_timestamp(created_at) < unix_timestamp() - ` + duration |
|
} else if db.DriverSqlite == driverName { |
|
raw = `strftime('%s', created_at) < strftime('%s', 'now') - ` + duration |
|
} else if db.DriverMssql == driverName { |
|
raw = `DATEDIFF(second, [created_at], GETDATE()) > ` + duration |
|
} else if db.DriverOceanBase == driverName { |
|
raw = `unix_timestamp(created_at) < unix_timestamp() - ` + duration |
|
} |
|
|
|
if raw != "" { |
|
_ = driver.table().WhereRaw(raw).Delete() |
|
} |
|
} |
|
|
|
|
|
func (driver *DBDriver) Update(sid string, values map[string]interface{}) error { |
|
|
|
go driver.deleteOverdueSession() |
|
|
|
if sid != "" { |
|
if len(values) == 0 { |
|
err := driver.table().Where("sid", "=", sid).Delete() |
|
if db.CheckError(err, db.DELETE) { |
|
return err |
|
} |
|
} |
|
valuesByte, err := json.Marshal(values) |
|
if err != nil { |
|
return err |
|
} |
|
sesValue := string(valuesByte) |
|
sesModel, _ := driver.table().Where("sid", "=", sid).First() |
|
if sesModel == nil { |
|
if !config.GetNoLimitLoginIP() { |
|
err = driver.table().Where("values", "=", sesValue).Delete() |
|
if db.CheckError(err, db.DELETE) { |
|
return err |
|
} |
|
} |
|
_, err := driver.table().Insert(dialect.H{ |
|
"values": sesValue, |
|
"sid": sid, |
|
}) |
|
if db.CheckError(err, db.INSERT) { |
|
return err |
|
} |
|
} else { |
|
_, err := driver.table(). |
|
Where("sid", "=", sid). |
|
Update(dialect.H{ |
|
"values": sesValue, |
|
}) |
|
if db.CheckError(err, db.UPDATE) { |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func (driver *DBDriver) table() *db.SQL { |
|
return db.Table(driver.tableName).WithDriver(driver.conn) |
|
} |
|
|