admin / modules /db /dialect /dialect.go
AZLABS's picture
Upload folder using huggingface_hub
530729e verified
// Copyright 2019 GoAdmin Core Team. All rights reserved.
// Use of this source code is governed by a Apache-2.0 style
// license that can be found in the LICENSE file.
package dialect
import (
"strings"
"github.com/GoAdminGroup/go-admin/modules/config"
)
// Dialect is the set of methods that each database driver provides for
// building SQL. GetDialectByDriver selects the implementation for a driver.
//
// NOTE(review): every method returns a string; judging from the names these
// appear to be SQL statement texts — confirm against the driver implementations.
type Dialect interface {
	// GetName returns the dialect's name.
	GetName() string
	// ShowColumns shows the columns of the specified table.
	ShowColumns(table string) string
	// ShowColumnsWithComment shows the columns of the specified table in the
	// given schema, together with their comments.
	ShowColumnsWithComment(schema, table string) string
	// ShowTables shows the tables of the database.
	ShowTables() string
	// Insert renders an insert statement for comp.
	Insert(comp *SQLComponent) string
	// Delete renders a delete statement for comp.
	Delete(comp *SQLComponent) string
	// Update renders an update statement for comp.
	Update(comp *SQLComponent) string
	// Select renders a select statement for comp.
	Select(comp *SQLComponent) string
	// GetDelimiter returns the identifier delimiter of the Dialect.
	// NOTE(review): dialects carry an opening and a closing delimiter
	// (e.g. "[" and "]" for mssql); presumably this is the opening one — confirm.
	GetDelimiter() string
}
// GetDialect returns the Dialect matching the driver of the default database
// in the global configuration.
func GetDialect() Dialect {
	defaultDB := config.GetDatabases().GetDefault()
	return GetDialectByDriver(defaultDB.Driver)
}
// GetDialectByDriver returns the Dialect for the given driver name.
//
// Recognized drivers are "mysql", "mssql", "postgresql", "sqlite" and
// "oceanbase". mssql quotes identifiers with [ and ], postgresql with double
// quotes, and the others with backticks. Any unrecognized driver falls back
// to a plain commonDialect that quotes with backticks.
func GetDialectByDriver(driver string) Dialect {
	backtick := commonDialect{delimiter: "`", delimiter2: "`"}

	switch driver {
	case "mysql":
		return mysql{commonDialect: backtick}
	case "mssql":
		return mssql{commonDialect: commonDialect{delimiter: "[", delimiter2: "]"}}
	case "postgresql":
		return postgresql{commonDialect: commonDialect{delimiter: `"`, delimiter2: `"`}}
	case "sqlite":
		return sqlite{commonDialect: backtick}
	case "oceanbase":
		return oceanbase{commonDialect: backtick}
	}

	return backtick
}
// H is a shorthand of map[string]interface{}. In this file it is the type of
// SQLComponent.Values, where each key is a column name and each value is the
// value bound to that column.
type H map[string]interface{}
// SQLComponent is a set of SQL components: the pieces of a single statement
// plus its bind arguments. The helper methods in this file render the
// individual clauses, and prepareUpdate / prepareInsert write the finished
// text into Statement.
type SQLComponent struct {
	// Fields lists the selected columns; getFields renders "*" when it is empty.
	Fields []string
	// Functions is indexed in parallel with Fields; a non-empty entry names
	// a SQL function that getFields applies to the field at the same index.
	Functions []string
	// TableName is the statement's table (used by prepareUpdate and prepareInsert).
	TableName string
	// Wheres are the conditions that getWheres joins with "and".
	Wheres []Where
	// Leftjoins are rendered as "left join" clauses by getJoins.
	Leftjoins []Join
	// Args holds the bind arguments. prepareUpdate prepends the SET-clause
	// arguments to it; prepareInsert appends the inserted values to it.
	Args []interface{}
	// Order is raw text placed after "order by".
	Order string
	// Offset is raw text placed after "offset".
	Offset string
	// Limit is raw text placed after "limit".
	Limit string
	// WhereRaws is raw SQL added to the where clause, after any Wheres and
	// joined to them with "and".
	WhereRaws string
	// UpdateRaws are raw SET expressions that prepareUpdate emits after the
	// Values assignments.
	UpdateRaws []RawUpdate
	// Group is raw text placed after "group by".
	Group string
	// Statement receives the SQL text rendered by prepareUpdate / prepareInsert.
	Statement string
	// Values maps column names to the values to insert or update.
	Values H
}
// Where contains the operation and field of one where-condition.
//
// getWheres renders it as "<Field> <Operation> <Qmark>". If Field contains a
// dot, the first segment is emitted as-is, only the second segment is quoted,
// and any further segments are dropped.
type Where struct {
	// Operation is the comparison operator, emitted verbatim.
	Operation string
	// Field is the column name, optionally qualified as "table.column".
	Field string
	// Qmark is the bind placeholder text, emitted verbatim.
	Qmark string
}
// Join contains the table, fields and operation of one left join.
//
// getJoins renders it as "left join <Table> on <FieldA> <Operation> <FieldB>",
// quoting Table and each side of the condition.
type Join struct {
	// Table is the joined table; it is quoted with the dialect delimiters.
	Table string
	// FieldA is the left operand, expected in "table.column" form.
	FieldA string
	// Operation is the comparison operator, emitted verbatim.
	Operation string
	// FieldB is the right operand, expected in "table.column" form.
	FieldB string
}
// RawUpdate contains a raw SET expression and its bind arguments.
//
// prepareUpdate appends Expression verbatim to the SET clause and appends
// Args to the statement's bind arguments in the same order.
type RawUpdate struct {
	// Expression is raw SQL such as an assignment, emitted verbatim.
	Expression string
	// Args are the bind arguments for the placeholders in Expression.
	Args []interface{}
}
// *******************************
// internal helper functions
// *******************************
// getLimit renders the LIMIT clause from the raw Limit text, padded with a
// space on both sides. It returns "" when no limit was set.
func (sql *SQLComponent) getLimit() string {
	if limit := sql.Limit; limit != "" {
		return " limit " + limit + " "
	}
	return ""
}
// getOffset renders the OFFSET clause from the raw Offset text, padded with a
// space on both sides. It returns "" when no offset was set.
func (sql *SQLComponent) getOffset() string {
	if sql.Offset != "" {
		return " offset " + sql.Offset + " "
	}
	return ""
}
// getOrderBy renders the ORDER BY clause from the raw Order text, padded with
// a space on both sides. It is empty when no ordering was requested.
func (sql *SQLComponent) getOrderBy() string {
	clause := ""
	if sql.Order != "" {
		clause = " order by " + sql.Order + " "
	}
	return clause
}
// getGroupBy renders the GROUP BY clause from the raw Group text, padded with
// a space on both sides. It is empty when no grouping was requested.
func (sql *SQLComponent) getGroupBy() string {
	switch group := sql.Group; group {
	case "":
		return ""
	default:
		return " group by " + group + " "
	}
}
// getJoins renders every entry of Leftjoins, in order, as
// " left join <table> on <FieldA> <Operation> <FieldB> ". The table is quoted
// with wrap and both operands with processLeftJoinField. It returns "" when
// there are no joins.
func (sql *SQLComponent) getJoins(delimiter, delimiter2 string) string {
	var out []byte
	for _, j := range sql.Leftjoins {
		clause := " left join " + wrap(delimiter, delimiter2, j.Table) +
			" on " + sql.processLeftJoinField(j.FieldA, delimiter, delimiter2) +
			" " + j.Operation + " " +
			sql.processLeftJoinField(j.FieldB, delimiter, delimiter2) + " "
		out = append(out, clause...)
	}
	return string(out)
}
// processLeftJoinField quotes a join operand of the form "table.column" as
// <delimiter>table<delimiter2>.<delimiter>column<delimiter2>. Only the first
// two dot-separated segments are used. An operand without a dot is returned
// unchanged.
func (sql *SQLComponent) processLeftJoinField(field, delimiter, delimiter2 string) string {
	arr := strings.Split(field, ".")
	// strings.Split always returns at least one element, so a second segment
	// must be present before arr[1] can be indexed; otherwise a dot-less operand
	// would panic with an index out of range.
	if len(arr) > 1 {
		return delimiter + arr[0] + delimiter2 + "." + delimiter + arr[1] + delimiter2
	}
	return field
}
// getFields renders the select list as a comma-separated string, or "*" when
// no fields were requested.
//
// If Leftjoins is non-empty, a "table.column" field is quoted segment by
// segment, so it stays table-qualified. Otherwise each field is quoted as a
// single identifier. A non-empty Functions[k] wraps the k-th field as
// Functions[k](field), with or without joins.
func (sql *SQLComponent) getFields(delimiter, delimiter2 string) string {
	if len(sql.Fields) == 0 {
		return "*"
	}

	parts := make([]string, 0, len(sql.Fields))
	for k, field := range sql.Fields {
		quoted := wrap(delimiter, delimiter2, field)
		if len(sql.Leftjoins) != 0 {
			if arr := strings.Split(field, "."); len(arr) > 1 {
				quoted = wrap(delimiter, delimiter2, arr[0]) + "." + wrap(delimiter, delimiter2, arr[1])
			}
		}
		// Functions is meant to run parallel to Fields; tolerate a shorter or
		// nil slice instead of panicking with an index out of range.
		if k < len(sql.Functions) && sql.Functions[k] != "" {
			quoted = sql.Functions[k] + "(" + quoted + ")"
		}
		parts = append(parts, quoted)
	}
	return strings.Join(parts, ",")
}
// wrap quotes an identifier with the opening delimiter and the closing
// delimiter2. The wildcard "*" is returned untouched so that it keeps its
// meaning in a select list.
func wrap(delimiter, delimiter2, field string) string {
	switch field {
	case "*":
		return field
	default:
		return delimiter + field + delimiter2
	}
}
// getWheres renders the WHERE clause. Each entry of Wheres becomes
// "<field> <operation> <qmark>". For a dotted field, the first segment is
// emitted as-is and only the second is quoted. WhereRaws, if set, is added as
// the final condition. Conditions are joined with " and ". The result is ""
// when there is nothing to filter on.
func (sql *SQLComponent) getWheres(delimiter, delimiter2 string) string {
	if len(sql.Wheres) == 0 {
		if sql.WhereRaws == "" {
			return ""
		}
		return " where " + sql.WhereRaws
	}

	conds := make([]string, 0, len(sql.Wheres)+1)
	for _, w := range sql.Wheres {
		var column string
		if parts := strings.Split(w.Field, "."); len(parts) > 1 {
			column = parts[0] + "." + wrap(delimiter, delimiter2, parts[1])
		} else {
			column = wrap(delimiter, delimiter2, w.Field)
		}
		conds = append(conds, column+" "+w.Operation+" "+w.Qmark)
	}
	if sql.WhereRaws != "" {
		conds = append(conds, sql.WhereRaws)
	}
	return " where " + strings.Join(conds, " and ")
}
// prepareUpdate writes an UPDATE statement for sql.TableName into
// sql.Statement and prepends the SET-clause bind arguments to sql.Args. The
// existing Args follow them, matching the WHERE clause that follows SET.
//
// The SET clause holds one "<column> = ?" per entry of sql.Values, followed by
// the raw expressions of sql.UpdateRaws. It panics when both are empty.
// Values is a map, so the column order can differ between calls. Its bind
// arguments are gathered in the same pass, so they stay aligned with the
// placeholders.
func (sql *SQLComponent) prepareUpdate(delimiter, delimiter2 string) {
	if len(sql.Values) == 0 && len(sql.UpdateRaws) == 0 {
		panic("prepareUpdate: wrong parameter")
	}

	setClause := ""
	args := make([]interface{}, 0)
	for column, value := range sql.Values {
		setClause += wrap(delimiter, delimiter2, column) + " = ?, "
		args = append(args, value)
	}

	if len(sql.UpdateRaws) == 0 {
		// Only Values were given: drop the trailing ", " separator.
		setClause = setClause[:len(setClause)-len(", ")]
	} else {
		last := len(sql.UpdateRaws) - 1
		for i, raw := range sql.UpdateRaws {
			sep := ","
			if i == last {
				sep = " "
			}
			setClause += raw.Expression + sep
			args = append(args, raw.Args...)
		}
	}

	sql.Args = append(args, sql.Args...)
	sql.Statement = "update " + delimiter + sql.TableName + delimiter2 + " set " + setClause + sql.getWheres(delimiter, delimiter2)
}
// prepareInsert writes an INSERT statement for sql.TableName into
// sql.Statement. It emits one quoted column and one "?" placeholder per entry
// of sql.Values, and appends each value to sql.Args in column order.
//
// Values is a map, so the column order can differ between calls. The columns,
// placeholders and appended arguments are produced in the same pass and
// always stay aligned.
//
// An empty Values map renders as "insert into <table> () values ()" instead
// of having its opening parentheses sliced off.
// NOTE(review): "() values ()" is MySQL syntax; other dialects may need
// DEFAULT VALUES for an all-defaults insert.
func (sql *SQLComponent) prepareInsert(delimiter, delimiter2 string) {
	columns := make([]string, 0, len(sql.Values))
	for key, value := range sql.Values {
		columns = append(columns, wrap(delimiter, delimiter2, key))
		sql.Args = append(sql.Args, value)
	}
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(columns)), ",")
	sql.Statement = "insert into " + delimiter + sql.TableName + delimiter2 +
		" (" + strings.Join(columns, ",") + ")" +
		" values (" + placeholders + ")"
}