Commit 8bfed750 by Torkel Ödegaard

More work on sql schema and migrations, starting to get somewhere

parent 68a77c40
...@@ -10,6 +10,7 @@ var ( ...@@ -10,6 +10,7 @@ var (
ErrAccountNotFound = errors.New("Account not found") ErrAccountNotFound = errors.New("Account not found")
) )
// Directly mapped to db schema, Do not change field names lighly
type Account struct { type Account struct {
Id int64 Id int64
Login string `xorm:"UNIQUE NOT NULL"` Login string `xorm:"UNIQUE NOT NULL"`
......
package migrations package migrations
type migration struct { import (
desc string "fmt"
sqlite string "strings"
mysql string )
verifyTable string
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
)
type Migration interface {
Sql(dialect Dialect) string
} }
type columnType string type ColumnType string
const ( const (
DB_TYPE_STRING columnType = "String" DB_TYPE_STRING ColumnType = "String"
) )
func (m *migration) getSql(dbType string) string { type MigrationBase struct {
switch dbType { desc string
case "mysql": }
type RawSqlMigration struct {
MigrationBase
sqlite string
mysql string
}
func (m *RawSqlMigration) Sql(dialect Dialect) string {
switch dialect.DriverName() {
case MYSQL:
return m.mysql return m.mysql
case "sqlite3": case SQLITE:
return m.sqlite return m.sqlite
} }
panic("db type not supported") panic("db type not supported")
} }
type migrationBuilder struct { func (m *RawSqlMigration) Sqlite(sql string) *RawSqlMigration {
migration *migration m.sqlite = sql
return m
}
func (m *RawSqlMigration) Mysql(sql string) *RawSqlMigration {
m.mysql = sql
return m
}
func (m *RawSqlMigration) Desc(desc string) *RawSqlMigration {
m.desc = desc
return m
}
type AddColumnMigration struct {
MigrationBase
tableName string
columnName string
columnType ColumnType
length int
}
func (m *AddColumnMigration) Table(tableName string) *AddColumnMigration {
m.tableName = tableName
return m
}
func (m *AddColumnMigration) Length(length int) *AddColumnMigration {
m.length = length
return m
}
func (m *AddColumnMigration) Column(columnName string) *AddColumnMigration {
m.columnName = columnName
return m
}
func (m *AddColumnMigration) Type(columnType ColumnType) *AddColumnMigration {
m.columnType = columnType
return m
}
func (m *AddColumnMigration) Sql(dialect Dialect) string {
return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", m.tableName, m.columnName, dialect.ToDBTypeSql(m.columnType, m.length))
}
func (m *AddColumnMigration) Desc(desc string) *AddColumnMigration {
m.desc = desc
return m
} }
func (b *migrationBuilder) sqlite(sql string) *migrationBuilder { type AddIndexMigration struct {
b.migration.sqlite = sql MigrationBase
return b tableName string
columns string
indexName string
} }
func (b *migrationBuilder) mysql(sql string) *migrationBuilder { func (m *AddIndexMigration) Name(name string) *AddIndexMigration {
b.migration.mysql = sql m.indexName = name
return b return m
} }
func (b *migrationBuilder) verifyTable(name string) *migrationBuilder { func (m *AddIndexMigration) Table(tableName string) *AddIndexMigration {
b.migration.verifyTable = name m.tableName = tableName
return b return m
} }
func (b *migrationBuilder) add() *migrationBuilder { func (m *AddIndexMigration) Columns(columns ...string) *AddIndexMigration {
migrationList = append(migrationList, b.migration) m.columns = strings.Join(columns, ",")
return b return m
} }
func (b *migrationBuilder) desc(desc string) *migrationBuilder { func (m *AddIndexMigration) Sql(dialect Dialect) string {
b.migration = &migration{desc: desc} return fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s(%s)", m.indexName, m.tableName, m.columns)
return b
} }
package sqlsyntax package migrations
import "fmt"
type Dialect interface { type Dialect interface {
DBType() string DriverName() string
ToDBTypeSql(columnType ColumnType, length int) string
TableCheckSql(tableName string) (string, []interface{}) TableCheckSql(tableName string) (string, []interface{})
} }
...@@ -11,12 +14,30 @@ type Sqlite3 struct { ...@@ -11,12 +14,30 @@ type Sqlite3 struct {
type Mysql struct { type Mysql struct {
} }
func (db *Sqlite3) DBType() string { func (db *Sqlite3) DriverName() string {
return "sqlite3" return SQLITE
}
func (db *Mysql) DriverName() string {
return MYSQL
} }
func (db *Mysql) DBType() string { func (db *Sqlite3) ToDBTypeSql(columnType ColumnType, length int) string {
return "mysql" switch columnType {
case DB_TYPE_STRING:
return "TEXT"
}
panic("Unsupported db type")
}
func (db *Mysql) ToDBTypeSql(columnType ColumnType, length int) string {
switch columnType {
case DB_TYPE_STRING:
return fmt.Sprintf("NVARCHAR(%d)", length)
}
panic("Unsupported db type")
} }
func (db *Sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *Sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
......
package migrations package migrations
import ( import (
"errors"
"fmt"
"github.com/torkelo/grafana-pro/pkg/services/sqlstore/sqlsyntax"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/go-xorm/xorm" "github.com/go-xorm/xorm"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/torkelo/grafana-pro/pkg/log"
) )
var x *xorm.Engine var x *xorm.Engine
var dialect sqlsyntax.Dialect var dialect Dialect
func getSchemaVersion() (int, error) { func getSchemaVersion() (int, error) {
exists, err := x.IsTableExist(new(SchemaVersion)) exists, err := x.IsTableExist(new(SchemaVersion))
...@@ -36,14 +32,16 @@ func getSchemaVersion() (int, error) { ...@@ -36,14 +32,16 @@ func getSchemaVersion() (int, error) {
func setEngineAndDialect(engine *xorm.Engine) { func setEngineAndDialect(engine *xorm.Engine) {
x = engine x = engine
switch x.DriverName() { switch x.DriverName() {
case "mysql": case MYSQL:
dialect = new(sqlsyntax.Mysql) dialect = new(Mysql)
case "sqlite3": case SQLITE:
dialect = new(sqlsyntax.Sqlite3) dialect = new(Sqlite3)
} }
} }
func StartMigration(engine *xorm.Engine) error { func StartMigration(engine *xorm.Engine) error {
log.Info("Starting database schema migration: DB: %v", engine.DriverName())
setEngineAndDialect(engine) setEngineAndDialect(engine)
_, err := getSchemaVersion() _, err := getSchemaVersion()
...@@ -60,9 +58,9 @@ func StartMigration(engine *xorm.Engine) error { ...@@ -60,9 +58,9 @@ func StartMigration(engine *xorm.Engine) error {
return nil return nil
} }
func execMigration(m *migration) error { func execMigration(m Migration) error {
err := inTransaction(func(sess *xorm.Session) error { err := inTransaction(func(sess *xorm.Session) error {
_, err := sess.Exec(m.getSql(x.DriverName())) _, err := sess.Exec(m.Sql(dialect))
if err != nil { if err != nil {
return err return err
} }
...@@ -73,17 +71,6 @@ func execMigration(m *migration) error { ...@@ -73,17 +71,6 @@ func execMigration(m *migration) error {
return err return err
} }
return verifyMigration(m)
}
func verifyMigration(m *migration) error {
if m.verifyTable != "" {
sqlStr, args := dialect.TableCheckSql(m.verifyTable)
results, err := x.Query(sqlStr, args...)
if err != nil || len(results) == 0 {
return errors.New(fmt.Sprintf("Verify failed: table %v does not exist", m.verifyTable))
}
}
return nil return nil
} }
......
package migrations package migrations
var migrationList []*migration var migrationList []Migration
// Id int64
// Login string `xorm:"UNIQUE NOT NULL"`
// Email string `xorm:"UNIQUE NOT NULL"`
// Name string
// FullName string
// Password string
// IsAdmin bool
// Salt string `xorm:"VARCHAR(10)"`
// Company string
// NextDashboardId int
// UsingAccountId int64
// Created time.Time
// Updated time.Time
func init() { func init() {
new(migrationBuilder). // ------------------------------
// ------------------------------ addMigration(new(RawSqlMigration).Desc("Create account table").
desc("Create account table"). Sqlite(`
sqlite(`
CREATE TABLE account ( CREATE TABLE account (
id INTEGER PRIMARY KEY AUTOINCREMENT id INTEGER PRIMARY KEY AUTOINCREMENT,
login TEXT NOT NULL,
email TEXT NOT NULL
) )
`). `).
mysql(` Mysql(`
CREATE TABLE account ( CREATE TABLE account (
id BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id) id BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id),
login VARCHAR(255) NOT NULL,
email VARCHAR(255) NOT NULL
) )
`). `))
verifyTable("account").add()
// ------------------------------ // ------------------------------
// desc("Add name column to account table"). addMigration(new(AddIndexMigration).
// table("account").addColumn("name").colType(DB_TYPE_STRING) Name("UIX_account_login").Table("account").Columns("login"))
// sqlite("ALTER TABLE account ADD COLUMN name TEXT"). // ------------------------------
// mysql("ALTER TABLE account ADD COLUMN name NVARCHAR(255)"). addMigration(new(AddColumnMigration).Desc("Add name column").
Table("account").Column("name").Type(DB_TYPE_STRING).Length(255))
}
func addMigration(m Migration) {
migrationList = append(migrationList, m)
} }
type SchemaVersion struct { type SchemaVersion struct {
......
...@@ -2,6 +2,7 @@ package migrations ...@@ -2,6 +2,7 @@ package migrations
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
"github.com/go-xorm/xorm" "github.com/go-xorm/xorm"
...@@ -27,10 +28,12 @@ func cleanDB(x *xorm.Engine) { ...@@ -27,10 +28,12 @@ func cleanDB(x *xorm.Engine) {
} }
} }
func TestMigrationsSqlite(t *testing.T) { var indexTypes = []string{"Unknown", "", "UNIQUE"}
func TestMigrations(t *testing.T) {
testDBs := [][]string{ testDBs := [][]string{
//[]string{"mysql", "grafana:password@tcp(localhost:3306)/grafana_tests?charset=utf8"},
[]string{"sqlite3", ":memory:"}, []string{"sqlite3", ":memory:"},
[]string{"mysql", "grafana:password@tcp(localhost:3306)/grafana_tests?charset=utf8"},
} }
for _, testDB := range testDBs { for _, testDB := range testDBs {
...@@ -43,13 +46,28 @@ func TestMigrationsSqlite(t *testing.T) { ...@@ -43,13 +46,28 @@ func TestMigrationsSqlite(t *testing.T) {
cleanDB(x) cleanDB(x)
} }
StartMigration(x) err = StartMigration(x)
So(err, ShouldBeNil)
tables, err := x.DBMetas() tables, err := x.DBMetas()
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(len(tables), ShouldEqual, 2) So(len(tables), ShouldEqual, 2)
}) fmt.Printf("\nDB Schema after migration: table count: %v\n", len(tables))
for _, table := range tables {
fmt.Printf("\nTable: %v \n", table.Name)
for _, column := range table.Columns() {
fmt.Printf("\t %v \n", column.String(x.Dialect()))
}
if len(table.Indexes) > 0 {
fmt.Printf("\n\tIndexes:\n")
for _, index := range table.Indexes {
fmt.Printf("\t %v (%v) %v \n", index.Name, strings.Join(index.Cols, ","), indexTypes[index.Type])
}
}
}
})
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment