Commit 1dfff74d by Dan Cech, committed by Torkel Ödegaard

move database-specific code into dialects (#11884)

parent 27e1c674
 CREATE LOGIN %%USER%% WITH PASSWORD = '%%PWD%%'
 GO
-CREATE DATABASE %%DB%%;
+CREATE DATABASE %%DB%%
+ON
+( NAME = %%DB%%,
+FILENAME = '/var/opt/mssql/data/%%DB%%.mdf',
+SIZE = 500MB,
+MAXSIZE = 1000MB,
+FILEGROWTH = 100MB )
+LOG ON
+( NAME = %%DB%%_log,
+FILENAME = '/var/opt/mssql/data/%%DB%%_log.ldf',
+SIZE = 500MB,
+MAXSIZE = 1000MB,
+FILEGROWTH = 100MB );
 GO
 USE %%DB%%;
......
@@ -4,7 +4,7 @@
 environment:
 ACCEPT_EULA: Y
 MSSQL_SA_PASSWORD: Password!
-MSSQL_PID: Express
+MSSQL_PID: Developer
 MSSQL_DATABASE: grafana
 MSSQL_USER: grafana
 MSSQL_PASSWORD: Password!
......
@@ -114,7 +114,7 @@ func HandleAlertsQuery(query *m.GetAlertsQuery) error {
 builder.Write(" ORDER BY name ASC")
 if query.Limit != 0 {
-builder.Write(" LIMIT ?", query.Limit)
+builder.Write(dialect.Limit(query.Limit))
 }
 alerts := make([]*m.AlertListItemDTO, 0)
......
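The change above is the pattern the whole commit follows: instead of hard-coding a `LIMIT ?` placeholder and binding a parameter, the query builder asks the active dialect to render the clause. A minimal, self-contained sketch of that idea (the `Dialect` name mirrors the interface extended later in this commit; the other names here are illustrative, not Grafana's actual types):

```go
package main

import (
	"fmt"
	"strings"
)

// Dialect is a pared-down stand-in for the migrator.Dialect interface
// this commit extends; only the Limit hook is sketched here.
type Dialect interface {
	Limit(limit int64) string
}

// ansiDialect mirrors the BaseDialect default added later in this commit:
// plain LIMIT syntax, which SQLite, MySQL and Postgres all accept.
type ansiDialect struct{}

func (ansiDialect) Limit(limit int64) string {
	return fmt.Sprintf(" LIMIT %d", limit)
}

func main() {
	var dialect Dialect = ansiDialect{}

	var sb strings.Builder
	sb.WriteString("SELECT * FROM alert ORDER BY name ASC")
	// Instead of appending " LIMIT ?" and binding query.Limit as a parameter,
	// the dialect renders the database-specific clause directly.
	sb.WriteString(dialect.Limit(10))

	fmt.Println(sb.String()) // SELECT * FROM alert ORDER BY name ASC LIMIT 10
}
```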
 package sqlstore
 import (
-"fmt"
 "testing"
 "github.com/grafana/grafana/pkg/components/simplejson"
@@ -21,7 +20,6 @@ func TestAlertNotificationSQLAccess(t *testing.T) {
 }
 err := GetAlertNotifications(cmd)
-fmt.Printf("error %v", err)
 So(err, ShouldBeNil)
 So(cmd.Result, ShouldBeNil)
 })
......
@@ -50,7 +50,7 @@ func (r *SqlAnnotationRepo) ensureTagsExist(sess *DBSession, tags []*models.Tag)
 var existingTag models.Tag
 // check if it exists
-if exists, err := sess.Table("tag").Where("`key`=? AND `value`=?", tag.Key, tag.Value).Get(&existingTag); err != nil {
+if exists, err := sess.Table("tag").Where(dialect.Quote("key")+"=? AND "+dialect.Quote("value")+"=?", tag.Key, tag.Value).Get(&existingTag); err != nil {
 return nil, err
 } else if exists {
 tag.Id = existingTag.Id
@@ -146,7 +146,7 @@ func (r *SqlAnnotationRepo) Find(query *annotations.ItemQuery) ([]*annotations.I
 params = append(params, query.OrgId)
 if query.AnnotationId != 0 {
-fmt.Print("annotation query")
+// fmt.Print("annotation query")
 sql.WriteString(` AND annotation.id = ?`)
 params = append(params, query.AnnotationId)
 }
@@ -193,10 +193,10 @@ func (r *SqlAnnotationRepo) Find(query *annotations.ItemQuery) ([]*annotations.I
 tags := models.ParseTagPairs(query.Tags)
 for _, tag := range tags {
 if tag.Value == "" {
-keyValueFilters = append(keyValueFilters, "(tag.key = ?)")
+keyValueFilters = append(keyValueFilters, "(tag."+dialect.Quote("key")+" = ?)")
 params = append(params, tag.Key)
 } else {
-keyValueFilters = append(keyValueFilters, "(tag.key = ? AND tag.value = ?)")
+keyValueFilters = append(keyValueFilters, "(tag."+dialect.Quote("key")+" = ? AND tag."+dialect.Quote("value")+" = ?)")
 params = append(params, tag.Key, tag.Value)
 }
 }
@@ -219,7 +219,7 @@ func (r *SqlAnnotationRepo) Find(query *annotations.ItemQuery) ([]*annotations.I
 query.Limit = 100
 }
-sql.WriteString(fmt.Sprintf(" ORDER BY epoch DESC LIMIT %v", query.Limit))
+sql.WriteString(" ORDER BY epoch DESC" + dialect.Limit(query.Limit))
 items := make([]*annotations.ItemDTO, 0)
......
@@ -10,12 +10,18 @@ import (
 )
 func TestSavingTags(t *testing.T) {
-Convey("Testing annotation saving/loading", t, func() {
 InitTestDB(t)
+Convey("Testing annotation saving/loading", t, func() {
 repo := SqlAnnotationRepo{}
 Convey("Can save tags", func() {
+Reset(func() {
+_, err := x.Exec("DELETE FROM annotation_tag WHERE 1=1")
+So(err, ShouldBeNil)
+})
 tagPairs := []*models.Tag{
 {Key: "outage"},
 {Key: "type", Value: "outage"},
@@ -31,12 +37,19 @@ func TestSavingTags(t *testing.T) {
 }
 func TestAnnotations(t *testing.T) {
-Convey("Testing annotation saving/loading", t, func() {
 InitTestDB(t)
+Convey("Testing annotation saving/loading", t, func() {
 repo := SqlAnnotationRepo{}
 Convey("Can save annotation", func() {
+Reset(func() {
+_, err := x.Exec("DELETE FROM annotation WHERE 1=1")
+So(err, ShouldBeNil)
+_, err = x.Exec("DELETE FROM annotation_tag WHERE 1=1")
+So(err, ShouldBeNil)
+})
 annotation := &annotations.Item{
 OrgId: 1,
 UserId: 1,
......
@@ -110,16 +110,15 @@ func TestDashboardSnapshotDBAccess(t *testing.T) {
 }
 func TestDeleteExpiredSnapshots(t *testing.T) {
-Convey("Testing dashboard snapshots clean up", t, func() {
 x := InitTestDB(t)
+Convey("Testing dashboard snapshots clean up", t, func() {
 setting.SnapShotRemoveExpired = true
-notExpiredsnapshot := createTestSnapshot(x, "key1", 1000)
-createTestSnapshot(x, "key2", -1000)
-createTestSnapshot(x, "key3", -1000)
+notExpiredsnapshot := createTestSnapshot(x, "key1", 1200)
+createTestSnapshot(x, "key2", -1200)
+createTestSnapshot(x, "key3", -1200)
-Convey("Clean up old dashboard snapshots", func() {
 err := DeleteExpiredSnapshots(&m.DeleteExpiredSnapshotsCommand{})
 So(err, ShouldBeNil)
@@ -132,20 +131,18 @@ func TestDeleteExpiredSnapshots(t *testing.T) {
 So(len(query.Result), ShouldEqual, 1)
 So(query.Result[0].Key, ShouldEqual, notExpiredsnapshot.Key)
-})
-Convey("Don't delete anything if there are no expired snapshots", func() {
-err := DeleteExpiredSnapshots(&m.DeleteExpiredSnapshotsCommand{})
+err = DeleteExpiredSnapshots(&m.DeleteExpiredSnapshotsCommand{})
 So(err, ShouldBeNil)
-query := m.GetDashboardSnapshotsQuery{
+query = m.GetDashboardSnapshotsQuery{
 OrgId: 1,
 SignedInUser: &m.SignedInUser{OrgRole: m.ROLE_ADMIN},
 }
 SearchDashboardSnapshots(&query)
 So(len(query.Result), ShouldEqual, 1)
-})
+So(query.Result[0].Key, ShouldEqual, notExpiredsnapshot.Key)
 })
 }
@@ -164,9 +161,11 @@ func createTestSnapshot(x *xorm.Engine, key string, expires int64) *m.DashboardS
 So(err, ShouldBeNil)
 // Set expiry date manually - to be able to create expired snapshots
+if expires < 0 {
 expireDate := time.Now().Add(time.Second * time.Duration(expires))
-_, err = x.Exec("update dashboard_snapshot set expires = ? where "+dialect.Quote("key")+" = ?", expireDate, key)
+_, err = x.Exec("UPDATE dashboard_snapshot SET expires = ? WHERE id = ?", expireDate, cmd.Result.Id)
 So(err, ShouldBeNil)
+}
 return cmd.Result
 }
@@ -86,10 +86,7 @@ func addAnnotationMig(mg *Migrator) {
 // clear alert text
 //
 updateTextFieldSql := "UPDATE annotation SET TEXT = '' WHERE alert_id > 0"
-mg.AddMigration("Update alert annotations and set TEXT to empty", new(RawSqlMigration).
-Sqlite(updateTextFieldSql).
-Postgres(updateTextFieldSql).
-Mysql(updateTextFieldSql))
+mg.AddMigration("Update alert annotations and set TEXT to empty", NewRawSqlMigration(updateTextFieldSql))
 //
 // Add a 'created' & 'updated' column
@@ -111,8 +108,5 @@ func addAnnotationMig(mg *Migrator) {
 // Convert epoch saved as seconds to miliseconds
 //
 updateEpochSql := "UPDATE annotation SET epoch = (epoch*1000) where epoch < 9999999999"
-mg.AddMigration("Convert existing annotations from seconds to milliseconds", new(RawSqlMigration).
-Sqlite(updateEpochSql).
-Postgres(updateEpochSql).
-Mysql(updateEpochSql))
+mg.AddMigration("Convert existing annotations from seconds to milliseconds", NewRawSqlMigration(updateEpochSql))
 }
@@ -45,8 +45,5 @@ INSERT INTO dashboard_acl
 (-1,-1, 2,'Editor','2017-06-20','2017-06-20')
 `
-mg.AddMigration("save default acl rules in dashboard_acl table", new(RawSqlMigration).
-Sqlite(rawSQL).
-Postgres(rawSQL).
-Mysql(rawSQL))
+mg.AddMigration("save default acl rules in dashboard_acl table", NewRawSqlMigration(rawSQL))
 }
@@ -90,9 +90,7 @@ func addDashboardMigration(mg *Migrator) {
 mg.AddMigration("drop table dashboard_v1", NewDropTableMigration("dashboard_v1"))
 // change column type of dashboard.data
-mg.AddMigration("alter dashboard.data to mediumtext v1", new(RawSqlMigration).
-Sqlite("SELECT 0 WHERE 0;").
-Postgres("SELECT 0;").
+mg.AddMigration("alter dashboard.data to mediumtext v1", NewRawSqlMigration("").
 Mysql("ALTER TABLE dashboard MODIFY data MEDIUMTEXT;"))
 // add column to store updater of a dashboard
@@ -157,7 +155,7 @@ func addDashboardMigration(mg *Migrator) {
 Name: "uid", Type: DB_NVarchar, Length: 40, Nullable: true,
 }))
-mg.AddMigration("Update uid column values in dashboard", new(RawSqlMigration).
+mg.AddMigration("Update uid column values in dashboard", NewRawSqlMigration("").
 Sqlite("UPDATE dashboard SET uid=printf('%09d',id) WHERE uid IS NULL;").
 Postgres("UPDATE dashboard SET uid=lpad('' || id,9,'0') WHERE uid IS NULL;").
 Mysql("UPDATE dashboard SET uid=lpad(id,9,'0') WHERE uid IS NULL;"))
......
@@ -50,9 +50,7 @@ func addDashboardSnapshotMigrations(mg *Migrator) {
 addTableIndicesMigrations(mg, "v5", snapshotV5)
 // change column type of dashboard
-mg.AddMigration("alter dashboard_snapshot to mediumtext v2", new(RawSqlMigration).
-Sqlite("SELECT 0 WHERE 0;").
-Postgres("SELECT 0;").
+mg.AddMigration("alter dashboard_snapshot to mediumtext v2", NewRawSqlMigration("").
 Mysql("ALTER TABLE dashboard_snapshot MODIFY dashboard MEDIUMTEXT;"))
 mg.AddMigration("Update dashboard_snapshot table charset", NewTableCharsetMigration("dashboard_snapshot", []*Column{
......
@@ -28,10 +28,7 @@ func addDashboardVersionMigration(mg *Migrator) {
 // before new dashboards where created with version 0, now they are always inserted with version 1
 const setVersionTo1WhereZeroSQL = `UPDATE dashboard SET version = 1 WHERE version = 0`
-mg.AddMigration("Set dashboard version to 1 where 0", new(RawSqlMigration).
-Sqlite(setVersionTo1WhereZeroSQL).
-Postgres(setVersionTo1WhereZeroSQL).
-Mysql(setVersionTo1WhereZeroSQL))
+mg.AddMigration("Set dashboard version to 1 where 0", NewRawSqlMigration(setVersionTo1WhereZeroSQL))
 const rawSQL = `INSERT INTO dashboard_version
 (
@@ -54,14 +51,9 @@ SELECT
 '',
 dashboard.data
 FROM dashboard;`
-mg.AddMigration("save existing dashboard data in dashboard_version table v1", new(RawSqlMigration).
-Sqlite(rawSQL).
-Postgres(rawSQL).
-Mysql(rawSQL))
+mg.AddMigration("save existing dashboard data in dashboard_version table v1", NewRawSqlMigration(rawSQL))
 // change column type of dashboard_version.data
-mg.AddMigration("alter dashboard_version.data to mediumtext v1", new(RawSqlMigration).
-Sqlite("SELECT 0 WHERE 0;").
-Postgres("SELECT 0;").
+mg.AddMigration("alter dashboard_version.data to mediumtext v1", NewRawSqlMigration("").
 Mysql("ALTER TABLE dashboard_version MODIFY data MEDIUMTEXT;"))
 }
@@ -122,10 +122,7 @@ func addDataSourceMigration(mg *Migrator) {
 }))
 const setVersionToOneWhereZero = `UPDATE data_source SET version = 1 WHERE version = 0`
-mg.AddMigration("Update initial version to 1", new(RawSqlMigration).
-Sqlite(setVersionToOneWhereZero).
-Postgres(setVersionToOneWhereZero).
-Mysql(setVersionToOneWhereZero))
+mg.AddMigration("Update initial version to 1", NewRawSqlMigration(setVersionToOneWhereZero))
 mg.AddMigration("Add read_only data column", NewAddColumnMigration(tableV2, &Column{
 Name: "read_only", Type: DB_Bool, Nullable: true,
......
@@ -25,7 +25,7 @@ func TestMigrations(t *testing.T) {
 x, err := xorm.NewEngine(testDB.DriverName, testDB.ConnStr)
 So(err, ShouldBeNil)
-sqlutil.CleanDB(x)
+NewDialect(x).CleanDB()
 _, err = x.SQL(sql).Get(&r)
 So(err, ShouldNotBeNil)
......
@@ -85,8 +85,5 @@ func addOrgMigrations(mg *Migrator) {
 }))
 const migrateReadOnlyViewersToViewers = `UPDATE org_user SET role = 'Viewer' WHERE role = 'Read Only Editor'`
-mg.AddMigration("Migrate all Read Only Viewers to Viewers", new(RawSqlMigration).
-Sqlite(migrateReadOnlyViewersToViewers).
-Postgres(migrateReadOnlyViewersToViewers).
-Mysql(migrateReadOnlyViewersToViewers))
+mg.AddMigration("Migrate all Read Only Viewers to Viewers", NewRawSqlMigration(migrateReadOnlyViewersToViewers))
 }
@@ -22,8 +22,7 @@ func addUserAuthMigrations(mg *Migrator) {
 // add indices
 addTableIndicesMigrations(mg, "v1", userAuthV1)
-mg.AddMigration("alter user_auth.auth_id to length 190", new(RawSqlMigration).
-Sqlite("SELECT 0 WHERE 0;").
+mg.AddMigration("alter user_auth.auth_id to length 190", NewRawSqlMigration("").
 Postgres("ALTER TABLE user_auth ALTER COLUMN auth_id TYPE VARCHAR(190);").
 Mysql("ALTER TABLE user_auth MODIFY auth_id VARCHAR(190);"))
 }
@@ -15,48 +15,9 @@ type Column struct {
 }
 func (col *Column) String(d Dialect) string {
-sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
-sql += d.SqlType(col) + " "
-if col.IsPrimaryKey {
-sql += "PRIMARY KEY "
-if col.IsAutoIncrement {
-sql += d.AutoIncrStr() + " "
-}
-}
-if d.ShowCreateNull() {
-if col.Nullable {
-sql += "NULL "
-} else {
-sql += "NOT NULL "
-}
-}
-if col.Default != "" {
-sql += "DEFAULT " + col.Default + " "
-}
-return sql
+return d.ColString(col)
 }
 func (col *Column) StringNoPk(d Dialect) string {
-sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
-sql += d.SqlType(col) + " "
-if d.ShowCreateNull() {
-if col.Nullable {
-sql += "NULL "
-} else {
-sql += "NOT NULL "
-}
-}
-if col.Default != "" {
-sql += "DEFAULT " + d.Default(col) + " "
-}
-return sql
+return d.ColStringNoPk(col)
 }
@@ -3,11 +3,12 @@ package migrator
 import (
 "fmt"
 "strings"
+"github.com/go-xorm/xorm"
 )
 type Dialect interface {
 DriverName() string
-QuoteStr() string
 Quote(string) string
 AndStr() string
 AutoIncrStr() string
@@ -31,16 +32,29 @@ type Dialect interface {
 TableCheckSql(tableName string) (string, []interface{})
 RenameTable(oldName string, newName string) string
 UpdateTableSql(tableName string, columns []*Column) string
+ColString(*Column) string
+ColStringNoPk(*Column) string
+Limit(limit int64) string
+LimitOffset(limit int64, offset int64) string
+PreInsertId(table string, sess *xorm.Session) error
+PostInsertId(table string, sess *xorm.Session) error
+CleanDB() error
+NoOpSql() string
 }
-func NewDialect(name string) Dialect {
+func NewDialect(engine *xorm.Engine) Dialect {
+name := engine.DriverName()
 switch name {
 case MYSQL:
-return NewMysqlDialect()
+return NewMysqlDialect(engine)
 case SQLITE:
-return NewSqlite3Dialect()
+return NewSqlite3Dialect(engine)
 case POSTGRES:
-return NewPostgresDialect()
+return NewPostgresDialect(engine)
 }
 panic("Unsupported database type: " + name)
@@ -48,6 +62,7 @@ func NewDialect(name string) Dialect {
 type BaseDialect struct {
 dialect Dialect
+engine *xorm.Engine
 driverName string
 }
@@ -100,9 +115,12 @@ func (b *BaseDialect) CreateTableSql(table *Table) string {
 }
 if len(pkList) > 1 {
-sql += "PRIMARY KEY ( "
-sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
-sql += " ), "
+quotedCols := []string{}
+for _, col := range pkList {
+quotedCols = append(quotedCols, b.dialect.Quote(col))
+}
+sql += "PRIMARY KEY ( " + strings.Join(quotedCols, ",") + " ), "
 }
 sql = sql[:len(sql)-2] + ")"
@@ -127,9 +145,12 @@ func (db *BaseDialect) CreateIndexSql(tableName string, index *Index) string {
 idxName := index.XName(tableName)
-return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v);", unique,
-quote(idxName), quote(tableName),
-quote(strings.Join(index.Cols, quote(","))))
+quotedCols := []string{}
+for _, col := range index.Cols {
+quotedCols = append(quotedCols, db.dialect.Quote(col))
+}
+return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v);", unique, quote(idxName), quote(tableName), strings.Join(quotedCols, ","))
 }
 func (db *BaseDialect) QuoteColList(cols []string) string {
@@ -168,3 +189,74 @@ func (db *BaseDialect) DropIndexSql(tableName string, index *Index) string {
 func (db *BaseDialect) UpdateTableSql(tableName string, columns []*Column) string {
 return "-- NOT REQUIRED"
 }
+func (db *BaseDialect) ColString(col *Column) string {
+sql := db.dialect.Quote(col.Name) + " "
+sql += db.dialect.SqlType(col) + " "
+if col.IsPrimaryKey {
+sql += "PRIMARY KEY "
+if col.IsAutoIncrement {
+sql += db.dialect.AutoIncrStr() + " "
+}
+}
+if db.dialect.ShowCreateNull() {
+if col.Nullable {
+sql += "NULL "
+} else {
+sql += "NOT NULL "
+}
+}
+if col.Default != "" {
+sql += "DEFAULT " + db.dialect.Default(col) + " "
+}
+return sql
+}
+func (db *BaseDialect) ColStringNoPk(col *Column) string {
+sql := db.dialect.Quote(col.Name) + " "
+sql += db.dialect.SqlType(col) + " "
+if db.dialect.ShowCreateNull() {
+if col.Nullable {
+sql += "NULL "
+} else {
+sql += "NOT NULL "
+}
+}
+if col.Default != "" {
+sql += "DEFAULT " + db.dialect.Default(col) + " "
+}
+return sql
+}
+func (db *BaseDialect) Limit(limit int64) string {
+return fmt.Sprintf(" LIMIT %d", limit)
+}
+func (db *BaseDialect) LimitOffset(limit int64, offset int64) string {
+return fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset)
+}
+func (db *BaseDialect) PreInsertId(table string, sess *xorm.Session) error {
+return nil
+}
+func (db *BaseDialect) PostInsertId(table string, sess *xorm.Session) error {
+return nil
+}
+func (db *BaseDialect) CleanDB() error {
+return nil
+}
+func (db *BaseDialect) NoOpSql() string {
+return "SELECT 0;"
+}
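The `BaseDialect` defaults above cover SQLite, MySQL and Postgres, which all accept plain `LIMIT`/`OFFSET`. The commit also registers an `mssql` driver constant (see the types change further down), and SQL Server has no `LIMIT` keyword, so a dialect for it would override these hooks. The snippet below is a hypothetical sketch of such an override using T-SQL's `OFFSET ... FETCH` syntax; it is not the dialect this commit ships:

```go
package main

import "fmt"

// Mssql is a hypothetical dialect override; SQL Server paginates with
// OFFSET ... FETCH NEXT, which requires an ORDER BY (the call sites in this
// commit already append one before the limit clause).
type Mssql struct{}

func (db Mssql) Limit(limit int64) string {
	return fmt.Sprintf(" OFFSET 0 ROWS FETCH NEXT %d ROWS ONLY", limit)
}

func (db Mssql) LimitOffset(limit int64, offset int64) string {
	return fmt.Sprintf(" OFFSET %d ROWS FETCH NEXT %d ROWS ONLY", offset, limit)
}

func main() {
	fmt.Println("SELECT * FROM dashboard ORDER BY title" + Mssql{}.Limit(5000))
	fmt.Println("SELECT * FROM team ORDER BY name" + Mssql{}.LimitOffset(30, 60))
}
```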
@@ -24,37 +24,58 @@ func (m *MigrationBase) GetCondition() MigrationCondition {
 type RawSqlMigration struct {
 MigrationBase
-sqlite string
-mysql string
-postgres string
+sql map[string]string
 }
+func NewRawSqlMigration(sql string) *RawSqlMigration {
+m := &RawSqlMigration{}
+if sql != "" {
+m.Default(sql)
+}
+return m
+}
 func (m *RawSqlMigration) Sql(dialect Dialect) string {
-switch dialect.DriverName() {
-case MYSQL:
-return m.mysql
-case SQLITE:
-return m.sqlite
-case POSTGRES:
-return m.postgres
-}
-panic("db type not supported")
+if m.sql != nil {
+if val := m.sql[dialect.DriverName()]; val != "" {
+return val
+}
+if val := m.sql["default"]; val != "" {
+return val
+}
+}
+return dialect.NoOpSql()
 }
-func (m *RawSqlMigration) Sqlite(sql string) *RawSqlMigration {
-m.sqlite = sql
-return m
+func (m *RawSqlMigration) Set(dialect string, sql string) *RawSqlMigration {
+if m.sql == nil {
+m.sql = make(map[string]string)
+}
+m.sql[dialect] = sql
+return m
+}
+func (m *RawSqlMigration) Default(sql string) *RawSqlMigration {
+return m.Set("default", sql)
+}
+func (m *RawSqlMigration) Sqlite(sql string) *RawSqlMigration {
+return m.Set(SQLITE, sql)
 }
 func (m *RawSqlMigration) Mysql(sql string) *RawSqlMigration {
-m.mysql = sql
-return m
+return m.Set(MYSQL, sql)
 }
 func (m *RawSqlMigration) Postgres(sql string) *RawSqlMigration {
-m.postgres = sql
-return m
+return m.Set(POSTGRES, sql)
+}
+func (m *RawSqlMigration) Mssql(sql string) *RawSqlMigration {
+return m.Set(MSSQL, sql)
 }
 type AddColumnMigration struct {
......
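The reworked `RawSqlMigration` keeps one SQL string per driver in a map: a driver-specific entry wins, then the "default" entry set by `NewRawSqlMigration`, and if neither exists the dialect's `NoOpSql()` is returned, which is what lets the `Sqlite("SELECT 0 WHERE 0;")` boilerplate disappear from the migration files above. A small standalone re-creation of that lookup order (simplified names and types for illustration, not the real migrator package):

```go
package main

import "fmt"

// rawSQLMigration re-creates the lookup logic of the new RawSqlMigration.Sql:
// driver-specific SQL wins, then the "default" entry, then a dialect no-op.
type rawSQLMigration struct {
	sql map[string]string
}

func (m *rawSQLMigration) set(driver, sql string) *rawSQLMigration {
	if m.sql == nil {
		m.sql = make(map[string]string)
	}
	m.sql[driver] = sql
	return m
}

func (m *rawSQLMigration) sqlFor(driver, noOp string) string {
	if val := m.sql[driver]; val != "" {
		return val
	}
	if val := m.sql["default"]; val != "" {
		return val
	}
	return noOp
}

func main() {
	// Roughly NewRawSqlMigration("").Mysql("ALTER TABLE dashboard MODIFY data MEDIUMTEXT;")
	m := (&rawSQLMigration{}).set("mysql", "ALTER TABLE dashboard MODIFY data MEDIUMTEXT;")

	fmt.Println(m.sqlFor("mysql", "SELECT 0;"))    // MySQL gets the real DDL
	fmt.Println(m.sqlFor("postgres", "SELECT 0;")) // every other driver falls back to the no-op
}
```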
@@ -31,7 +31,7 @@ func NewMigrator(engine *xorm.Engine) *Migrator {
 mg.x = engine
 mg.Logger = log.New("migrator")
 mg.migrations = make([]Migration, 0)
-mg.dialect = NewDialect(mg.x.DriverName())
+mg.dialect = NewDialect(mg.x)
 return mg
 }
......
 package migrator
 import (
+"fmt"
 "strconv"
 "strings"
+"github.com/go-xorm/xorm"
 )
 type Mysql struct {
 BaseDialect
 }
-func NewMysqlDialect() *Mysql {
+func NewMysqlDialect(engine *xorm.Engine) *Mysql {
 d := Mysql{}
 d.BaseDialect.dialect = &d
+d.BaseDialect.engine = engine
 d.BaseDialect.driverName = MYSQL
 return &d
 }
@@ -24,10 +28,6 @@ func (db *Mysql) Quote(name string) string {
 return "`" + name + "`"
 }
-func (db *Mysql) QuoteStr() string {
-return "`"
-}
 func (db *Mysql) AutoIncrStr() string {
 return "AUTO_INCREMENT"
 }
@@ -105,3 +105,23 @@ func (db *Mysql) UpdateTableSql(tableName string, columns []*Column) string {
 return "ALTER TABLE " + db.Quote(tableName) + " " + strings.Join(statements, ", ") + ";"
 }
+func (db *Mysql) CleanDB() error {
+tables, _ := db.engine.DBMetas()
+sess := db.engine.NewSession()
+defer sess.Close()
+for _, table := range tables {
+if _, err := sess.Exec("set foreign_key_checks = 0"); err != nil {
+return fmt.Errorf("failed to disable foreign key checks")
+}
+if _, err := sess.Exec("drop table " + table.Name + " ;"); err != nil {
+return fmt.Errorf("failed to delete table: %v, err: %v", table.Name, err)
+}
+if _, err := sess.Exec("set foreign_key_checks = 1"); err != nil {
+return fmt.Errorf("failed to disable foreign key checks")
+}
+}
+return nil
+}
@@ -4,15 +4,18 @@ import (
 "fmt"
 "strconv"
 "strings"
+"github.com/go-xorm/xorm"
 )
 type Postgres struct {
 BaseDialect
 }
-func NewPostgresDialect() *Postgres {
+func NewPostgresDialect(engine *xorm.Engine) *Postgres {
 d := Postgres{}
 d.BaseDialect.dialect = &d
+d.BaseDialect.engine = engine
 d.BaseDialect.driverName = POSTGRES
 return &d
 }
@@ -25,10 +28,6 @@ func (db *Postgres) Quote(name string) string {
 return "\"" + name + "\""
 }
-func (db *Postgres) QuoteStr() string {
-return "\""
-}
 func (b *Postgres) LikeStr() string {
 return "ILIKE"
 }
@@ -117,8 +116,23 @@ func (db *Postgres) UpdateTableSql(tableName string, columns []*Column) string {
 var statements = []string{}
 for _, col := range columns {
-statements = append(statements, "ALTER "+db.QuoteStr()+col.Name+db.QuoteStr()+" TYPE "+db.SqlType(col))
+statements = append(statements, "ALTER "+db.Quote(col.Name)+" TYPE "+db.SqlType(col))
 }
 return "ALTER TABLE " + db.Quote(tableName) + " " + strings.Join(statements, ", ") + ";"
 }
+func (db *Postgres) CleanDB() error {
+sess := db.engine.NewSession()
+defer sess.Close()
+if _, err := sess.Exec("DROP SCHEMA public CASCADE;"); err != nil {
+return fmt.Errorf("Failed to drop schema public")
+}
+if _, err := sess.Exec("CREATE SCHEMA public;"); err != nil {
+return fmt.Errorf("Failed to create schema public")
+}
+return nil
+}
 package migrator
-import "fmt"
+import (
+"fmt"
+"github.com/go-xorm/xorm"
+)
 type Sqlite3 struct {
 BaseDialect
 }
-func NewSqlite3Dialect() *Sqlite3 {
+func NewSqlite3Dialect(engine *xorm.Engine) *Sqlite3 {
 d := Sqlite3{}
 d.BaseDialect.dialect = &d
+d.BaseDialect.engine = engine
 d.BaseDialect.driverName = SQLITE
 return &d
 }
@@ -21,10 +26,6 @@ func (db *Sqlite3) Quote(name string) string {
 return "`" + name + "`"
 }
-func (db *Sqlite3) QuoteStr() string {
-return "`"
-}
 func (db *Sqlite3) AutoIncrStr() string {
 return "AUTOINCREMENT"
 }
@@ -77,3 +78,7 @@ func (db *Sqlite3) DropIndexSql(tableName string, index *Index) string {
 idxName := index.XName(tableName)
 return fmt.Sprintf("DROP INDEX %v", quote(idxName))
 }
+func (db *Sqlite3) CleanDB() error {
+return nil
+}
@@ -9,6 +9,7 @@ const (
 POSTGRES = "postgres"
 SQLITE = "sqlite3"
 MYSQL = "mysql"
+MSSQL = "mssql"
 )
 type Migration interface {
......
@@ -64,7 +64,7 @@ func UpdatePlaylist(cmd *m.UpdatePlaylistCommand) error {
 Interval: playlist.Interval,
 }
-_, err := x.ID(cmd.Id).Cols("id", "name", "interval").Update(&playlist)
+_, err := x.ID(cmd.Id).Cols("name", "interval").Update(&playlist)
 if err != nil {
 return err
......
@@ -92,7 +92,7 @@ func (sb *SearchBuilder) ToSql() (string, []interface{}) {
 LEFT OUTER JOIN dashboard folder on folder.id = dashboard.folder_id
 LEFT OUTER JOIN dashboard_tag on dashboard.id = dashboard_tag.dashboard_id`)
-sb.sql.WriteString(" ORDER BY dashboard.title ASC LIMIT 5000")
+sb.sql.WriteString(" ORDER BY dashboard.title ASC" + dialect.Limit(5000))
 return sb.sql.String(), sb.params
 }
@@ -135,12 +135,11 @@ func (sb *SearchBuilder) buildTagQuery() {
 // this ends the inner select (tag filtered part)
 sb.sql.WriteString(`
 GROUP BY dashboard.id HAVING COUNT(dashboard.id) >= ?
-LIMIT ?) as ids
+ORDER BY dashboard.id` + dialect.Limit(int64(sb.limit)) + `) as ids
 INNER JOIN dashboard on ids.id = dashboard.id
 `)
 sb.params = append(sb.params, len(sb.tags))
-sb.params = append(sb.params, sb.limit)
 }
 func (sb *SearchBuilder) buildMainQuery() {
@@ -153,8 +152,7 @@ func (sb *SearchBuilder) buildMainQuery() {
 sb.sql.WriteString(` WHERE `)
 sb.buildSearchWhereClause()
-sb.sql.WriteString(` LIMIT ?) as ids INNER JOIN dashboard on ids.id = dashboard.id `)
-sb.params = append(sb.params, sb.limit)
+sb.sql.WriteString(` ORDER BY dashboard.title` + dialect.Limit(int64(sb.limit)) + `) as ids INNER JOIN dashboard on ids.id = dashboard.id `)
 }
 func (sb *SearchBuilder) buildSearchWhereClause() {
......
@@ -4,13 +4,10 @@ import (
 "testing"
 m "github.com/grafana/grafana/pkg/models"
-"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
 . "github.com/smartystreets/goconvey/convey"
 )
 func TestSearchBuilder(t *testing.T) {
-dialect = migrator.NewDialect("sqlite3")
 Convey("Testing building a search", t, func() {
 signedInUser := &m.SignedInUser{
 OrgId: 1,
@@ -23,7 +20,7 @@ func TestSearchBuilder(t *testing.T) {
 sql, params := sb.IsStarred().WithTitle("test").ToSql()
 So(sql, ShouldStartWith, "SELECT")
 So(sql, ShouldContainSubstring, "INNER JOIN dashboard on ids.id = dashboard.id")
-So(sql, ShouldEndWith, "ORDER BY dashboard.title ASC LIMIT 5000")
+So(sql, ShouldContainSubstring, "ORDER BY dashboard.title ASC")
 So(len(params), ShouldBeGreaterThan, 0)
 })
@@ -31,7 +28,7 @@ func TestSearchBuilder(t *testing.T) {
 sql, params := sb.WithTags([]string{"tag1", "tag2"}).ToSql()
 So(sql, ShouldStartWith, "SELECT")
 So(sql, ShouldContainSubstring, "LEFT OUTER JOIN dashboard_tag")
-So(sql, ShouldEndWith, "ORDER BY dashboard.title ASC LIMIT 5000")
+So(sql, ShouldContainSubstring, "ORDER BY dashboard.title ASC")
 So(len(params), ShouldBeGreaterThan, 0)
 })
 })
......
 package sqlstore
 import (
+"reflect"
 "time"
 "github.com/go-xorm/xorm"
@@ -67,3 +68,23 @@ func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
 return nil
 }
+func (sess *DBSession) InsertId(bean interface{}) (int64, error) {
+table := sess.DB().Mapper.Obj2Table(getTypeName(bean))
+dialect.PreInsertId(table, sess.Session)
+id, err := sess.Session.InsertOne(bean)
+dialect.PostInsertId(table, sess.Session)
+return id, err
+}
+func getTypeName(bean interface{}) (res string) {
+t := reflect.TypeOf(bean)
+for t.Kind() == reflect.Ptr {
+t = t.Elem()
+}
+return t.Name()
+}
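`InsertId` brackets a single-row insert with the new `PreInsertId`/`PostInsertId` hooks, keyed by the bean's table name. The `BaseDialect` versions are no-ops, so SQLite, MySQL and Postgres behave exactly as before; a dialect that needs special handling for inserts with explicit primary keys (SQL Server's `IDENTITY_INSERT` is the obvious case) can act only on the tables it is given. The sketch below shows the shape of such a hook against a fake session; the `identityInsertDialect` type and its SQL are assumptions for illustration, not code from this commit:

```go
package main

import "fmt"

// execer stands in for *xorm.Session in this sketch.
type execer interface {
	Exec(query string, args ...interface{}) (interface{}, error)
}

type fakeSession struct{}

func (fakeSession) Exec(query string, args ...interface{}) (interface{}, error) {
	fmt.Println("exec:", query)
	return nil, nil
}

// identityInsertDialect is a hypothetical dialect whose hooks toggle
// identity-column inserts around the statement; the base dialect in the
// commit simply returns nil from both hooks.
type identityInsertDialect struct{}

func (identityInsertDialect) PreInsertId(table string, sess execer) error {
	_, err := sess.Exec("SET IDENTITY_INSERT " + table + " ON")
	return err
}

func (identityInsertDialect) PostInsertId(table string, sess execer) error {
	_, err := sess.Exec("SET IDENTITY_INSERT " + table + " OFF")
	return err
}

func main() {
	d := identityInsertDialect{}
	sess := fakeSession{}

	// Mirrors DBSession.InsertId above: pre-hook, insert, post-hook.
	if err := d.PreInsertId("org", sess); err != nil {
		panic(err)
	}
	sess.Exec("INSERT INTO org (id, name) VALUES (?, ?)", 1, "Main Org.")
	if err := d.PostInsertId("org", sess); err != nil {
		panic(err)
	}
}
```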
@@ -20,7 +20,6 @@ import (
 "github.com/grafana/grafana/pkg/setting"
 "github.com/go-sql-driver/mysql"
-_ "github.com/go-sql-driver/mysql"
 "github.com/go-xorm/xorm"
 _ "github.com/lib/pq"
 _ "github.com/mattn/go-sqlite3"
@@ -97,7 +96,7 @@ func NewEngine() *xorm.Engine {
 func SetEngine(engine *xorm.Engine) (err error) {
 x = engine
-dialect = migrator.NewDialect(x.DriverName())
+dialect = migrator.NewDialect(x)
 migrator := migrator.NewMigrator(x)
 migrations.AddMigrations(migrator)
@@ -116,7 +115,7 @@ func getEngine() (*xorm.Engine, error) {
 cnnstr := ""
 switch DbCfg.Type {
-case "mysql":
+case migrator.MYSQL:
 protocol := "tcp"
 if strings.HasPrefix(DbCfg.Host, "/") {
 protocol = "unix"
@@ -133,7 +132,7 @@ func getEngine() (*xorm.Engine, error) {
 mysql.RegisterTLSConfig("custom", tlsCert)
 cnnstr += "&tls=custom"
 }
-case "postgres":
+case migrator.POSTGRES:
 var host, port = "127.0.0.1", "5432"
 fields := strings.Split(DbCfg.Host, ":")
 if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 {
@@ -153,7 +152,7 @@ func getEngine() (*xorm.Engine, error) {
 strings.Replace(DbCfg.ClientKeyPath, `'`, `\'`, -1),
 strings.Replace(DbCfg.CaCertPath, `'`, `\'`, -1),
 )
-case "sqlite3":
+case migrator.SQLITE:
 if !filepath.IsAbs(DbCfg.Path) {
 DbCfg.Path = filepath.Join(setting.DataPath, DbCfg.Path)
 }
@@ -230,16 +229,10 @@ func LoadConfig() {
 DbCfg.Path = sec.Key("path").MustString("data/grafana.db")
 }
-var (
-dbSqlite = "sqlite"
-dbMySql = "mysql"
-dbPostgres = "postgres"
-)
 func InitTestDB(t *testing.T) *xorm.Engine {
-selectedDb := dbSqlite
-// selectedDb := dbMySql
-// selectedDb := dbPostgres
+selectedDb := migrator.SQLITE
+// selectedDb := migrator.MYSQL
+// selectedDb := migrator.POSTGRES
 var x *xorm.Engine
 var err error
@@ -250,9 +243,9 @@ func InitTestDB(t *testing.T) *xorm.Engine {
 }
 switch strings.ToLower(selectedDb) {
-case dbMySql:
+case migrator.MYSQL:
 x, err = xorm.NewEngine(sqlutil.TestDB_Mysql.DriverName, sqlutil.TestDB_Mysql.ConnStr)
-case dbPostgres:
+case migrator.POSTGRES:
 x, err = xorm.NewEngine(sqlutil.TestDB_Postgres.DriverName, sqlutil.TestDB_Postgres.ConnStr)
 default:
 x, err = xorm.NewEngine(sqlutil.TestDB_Sqlite3.DriverName, sqlutil.TestDB_Sqlite3.ConnStr)
@@ -261,24 +254,29 @@ func InitTestDB(t *testing.T) *xorm.Engine {
 x.DatabaseTZ = time.UTC
 x.TZLocation = time.UTC
-// x.ShowSQL()
 if err != nil {
 t.Fatalf("Failed to init test database: %v", err)
 }
-sqlutil.CleanDB(x)
+dialect = migrator.NewDialect(x)
+err = dialect.CleanDB()
+if err != nil {
+t.Fatalf("Failed to clean test db %v", err)
+}
 if err := SetEngine(x); err != nil {
 t.Fatal(err)
 }
+// x.ShowSQL()
 return x
 }
 func IsTestDbMySql() bool {
 if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present {
-return db == dbMySql
+return db == migrator.MYSQL
 }
 return false
@@ -286,7 +284,7 @@ func IsTestDbMySql() bool {
 func IsTestDbPostgres() bool {
 if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present {
-return db == dbPostgres
+return db == migrator.POSTGRES
 }
 return false
......
 package sqlutil
-import (
-"fmt"
-"github.com/go-xorm/xorm"
-)
 type TestDB struct {
 DriverName string
 ConnStr string
@@ -15,34 +9,3 @@ var TestDB_Sqlite3 = TestDB{DriverName: "sqlite3", ConnStr: ":memory:"}
 var TestDB_Mysql = TestDB{DriverName: "mysql", ConnStr: "grafana:password@tcp(localhost:3306)/grafana_tests?collation=utf8mb4_unicode_ci"}
 var TestDB_Postgres = TestDB{DriverName: "postgres", ConnStr: "user=grafanatest password=grafanatest host=localhost port=5432 dbname=grafanatest sslmode=disable"}
 var TestDB_Mssql = TestDB{DriverName: "mssql", ConnStr: "server=localhost;port=1433;database=grafanatest;user id=grafana;password=Password!"}
-func CleanDB(x *xorm.Engine) {
-if x.DriverName() == "postgres" {
-sess := x.NewSession()
-defer sess.Close()
-if _, err := sess.Exec("DROP SCHEMA public CASCADE;"); err != nil {
-panic("Failed to drop schema public")
-}
-if _, err := sess.Exec("CREATE SCHEMA public;"); err != nil {
-panic("Failed to create schema public")
-}
-} else if x.DriverName() == "mysql" {
-tables, _ := x.DBMetas()
-sess := x.NewSession()
-defer sess.Close()
-for _, table := range tables {
-if _, err := sess.Exec("set foreign_key_checks = 0"); err != nil {
-panic("failed to disable foreign key checks")
-}
-if _, err := sess.Exec("drop table " + table.Name + " ;"); err != nil {
-panic(fmt.Sprintf("failed to delete table: %v, err: %v", table.Name, err))
-}
-if _, err := sess.Exec("set foreign_key_checks = 1"); err != nil {
-panic("failed to disable foreign key checks")
-}
-}
-}
-}
@@ -161,9 +161,8 @@ func SearchTeams(query *m.SearchTeamsQuery) error {
 sql.WriteString(` order by team.name asc`)
 if query.Limit != 0 {
-sql.WriteString(` limit ? offset ?`)
 offset := query.Limit * (query.Page - 1)
-params = append(params, query.Limit, offset)
+sql.WriteString(dialect.LimitOffset(int64(query.Limit), int64(offset)))
 }
 if err := x.Sql(sql.String(), params...).Find(&query.Result.Teams); err != nil {
......
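For paging, `SearchTeams` now computes the offset as `limit * (page - 1)` and lets the dialect render it; with the `BaseDialect` implementation added earlier this produces the same `LIMIT ... OFFSET ...` tail that used to be built from bound parameters. A tiny worked example of that arithmetic and rendering (standalone sketch, not the sqlstore package):

```go
package main

import "fmt"

// limitOffset mirrors the BaseDialect rendering added in this commit.
func limitOffset(limit, offset int64) string {
	return fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset)
}

func main() {
	var limit, page int64 = 30, 3
	offset := limit * (page - 1) // page 3 with 30 rows per page skips the first 60 rows

	fmt.Println("SELECT * FROM team ORDER BY team.name ASC" + limitOffset(limit, offset))
	// SELECT * FROM team ORDER BY team.name ASC LIMIT 30 OFFSET 60
}
```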
@@ -60,9 +60,15 @@ func getOrgIdForNewUser(cmd *m.CreateUserCommand, sess *DBSession) (int64, error
 org.Created = time.Now()
 org.Updated = time.Now()
-if _, err := sess.Insert(&org); err != nil {
+if org.Id != 0 {
+if _, err := sess.InsertId(&org); err != nil {
 return 0, err
 }
+} else {
+if _, err := sess.InsertOne(&org); err != nil {
+return 0, err
+}
+}
 sess.publishAfterCommit(&events.OrgCreated{
 Timestamp: org.Created,
......