Commit 81436100 by bergquist

bus: support multiple dispatch in one transaction

this makes it possible to run multiple DispatchCtx
in one transaction. The TransactionManager will
start/end the transaction and pass the dbsession
in the context.Context variable
parent e33d1870
......@@ -12,21 +12,51 @@ type Msg interface{}
var ErrHandlerNotFound = errors.New("handler not found")
type TransactionManager interface {
Begin(ctx context.Context) (context.Context, error)
End(ctx context.Context, err error) error
}
type Bus interface {
Dispatch(msg Msg) error
DispatchCtx(ctx context.Context, msg Msg) error
Publish(msg Msg) error
// InTransaction starts a transaction and store it in the context.
// The caller can then pass a function with multiple DispatchCtx calls that
// all will be executed in the same transaction. InTransaction will rollback if the
// callback returns an error.s
InTransaction(ctx context.Context, fn func(ctx context.Context) error) error
AddHandler(handler HandlerFunc)
AddCtxHandler(handler HandlerFunc)
AddEventListener(handler HandlerFunc)
AddWildcardListener(handler HandlerFunc)
// SetTransactionManager allows the user to replace the internal
// noop TransactionManager that is responsible for manageing
// transactions in `InTransaction`
SetTransactionManager(tm TransactionManager)
}
func (b *InProcBus) InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
ctxWithTran, err := b.transactionManager.Begin(ctx)
if err != nil {
return err
}
err = fn(ctxWithTran)
b.transactionManager.End(ctxWithTran, err)
return err
}
type InProcBus struct {
handlers map[string]HandlerFunc
listeners map[string][]HandlerFunc
wildcardListeners []HandlerFunc
transactionManager TransactionManager
}
// temp stuff, not sure how to handle bus instance, and init yet
......@@ -37,6 +67,9 @@ func New() Bus {
bus.handlers = make(map[string]HandlerFunc)
bus.listeners = make(map[string][]HandlerFunc)
bus.wildcardListeners = make([]HandlerFunc, 0)
bus.transactionManager = &NoopTransactionManager{}
return bus
}
......@@ -45,6 +78,14 @@ func GetBus() Bus {
return globalBus
}
func SetTransactionManager(tm TransactionManager) {
globalBus.SetTransactionManager(tm)
}
func (b *InProcBus) SetTransactionManager(tm TransactionManager) {
b.transactionManager = tm
}
func (b *InProcBus) DispatchCtx(ctx context.Context, msg Msg) error {
var msgName = reflect.TypeOf(msg).Elem().Name()
......@@ -167,6 +208,15 @@ func Publish(msg Msg) error {
return globalBus.Publish(msg)
}
func InTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
return globalBus.InTransaction(ctx, fn)
}
func ClearBusHandlers() {
globalBus = New()
}
type NoopTransactionManager struct{}
func (*NoopTransactionManager) Begin(ctx context.Context) (context.Context, error) { return ctx, nil }
func (*NoopTransactionManager) End(ctx context.Context, err error) error { return err }
......@@ -3,6 +3,7 @@ package notifiers
import (
"github.com/grafana/grafana/pkg/components/simplejson"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/alerting"
)
......
package sqlstore
import (
"context"
"reflect"
"time"
......@@ -29,10 +30,35 @@ func inTransaction(callback dbTransactionFunc) error {
return inTransactionWithRetry(callback, 0)
}
func startSession(ctx context.Context) *DBSession {
value := ctx.Value(ContextSessionName)
var sess *xorm.Session
sess, ok := value.(*xorm.Session)
if !ok {
return newSession()
}
old := newSession()
old.Session = sess
return old
}
func withDbSession(ctx context.Context, callback dbTransactionFunc) error {
sess := startSession(ctx)
return callback(sess)
}
func inTransactionWithRetry(callback dbTransactionFunc, retry int) error {
return inTransactionWithRetryCtx(context.Background(), callback, retry)
}
func inTransactionWithRetryCtx(ctx context.Context, callback dbTransactionFunc, retry int) error {
var err error
sess := newSession()
sess := startSession(ctx)
defer sess.Close()
if err = sess.Begin(); err != nil {
......
package sqlstore
import (
"context"
"errors"
"fmt"
"net/url"
"os"
......@@ -35,6 +37,8 @@ var (
sqlog log.Logger = log.New("sqlstore")
)
const ContextSessionName = "db-session"
func init() {
registry.Register(&registry.Descriptor{
Name: "SqlStore",
......@@ -45,6 +49,7 @@ func init() {
type SqlStore struct {
Cfg *setting.Cfg `inject:""`
Bus bus.Bus `inject:""`
dbCfg DatabaseConfig
engine *xorm.Engine
......@@ -77,6 +82,10 @@ func (ss *SqlStore) Init() error {
// Init repo instances
annotations.SetRepository(&SqlAnnotationRepo{})
ss.Bus.SetTransactionManager(&SQLTransactionManager{
engine: ss.engine,
})
// ensure admin user
if ss.skipEnsureAdmin {
return nil
......@@ -85,10 +94,47 @@ func (ss *SqlStore) Init() error {
return ss.ensureAdminUser()
}
type SQLTransactionManager struct {
engine *xorm.Engine
}
func (stm *SQLTransactionManager) Begin(ctx context.Context) (context.Context, error) {
sess := stm.engine.NewSession()
err := sess.Begin()
if err != nil {
return ctx, err
}
withValue := context.WithValue(ctx, ContextSessionName, sess)
return withValue, nil
}
func (stm *SQLTransactionManager) End(ctx context.Context, err error) error {
value := ctx.Value(ContextSessionName)
sess, ok := value.(*xorm.Session)
if !ok {
return errors.New("context is missing transaction")
}
if err != nil {
sess.Rollback()
return err
}
defer sess.Close()
return sess.Commit()
}
func (ss *SqlStore) ensureAdminUser() error {
systemUserCountQuery := m.GetSystemUserCountStatsQuery{}
if err := bus.Dispatch(&systemUserCountQuery); err != nil {
err := bus.InTransaction(context.Background(), func(ctx context.Context) error {
return bus.DispatchCtx(ctx, &systemUserCountQuery)
})
if err != nil {
return fmt.Errorf("Could not determine if admin user exists: %v", err)
}
......@@ -240,6 +286,7 @@ func (ss *SqlStore) readConfig() {
func InitTestDB(t *testing.T) *SqlStore {
sqlstore := &SqlStore{}
sqlstore.skipEnsureAdmin = true
sqlstore.Bus = bus.New()
dbType := migrator.SQLITE
......
package sqlstore
import (
"context"
"time"
"github.com/grafana/grafana/pkg/bus"
......@@ -13,6 +14,7 @@ func init() {
bus.AddHandler("sql", GetDataSourceAccessStats)
bus.AddHandler("sql", GetAdminStats)
bus.AddHandler("sql", GetSystemUserCountStats)
bus.AddCtxHandler("sql", GetSystemUserCountStatsCtx)
}
var activeUserTimeLimit = time.Hour * 24 * 30
......@@ -133,6 +135,22 @@ func GetAdminStats(query *m.GetAdminStatsQuery) error {
return err
}
func GetSystemUserCountStatsCtx(ctx context.Context, query *m.GetSystemUserCountStatsQuery) error {
return withDbSession(ctx, func(sess *DBSession) error {
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
var stats m.SystemUserCountStats
_, err := sess.SQL(rawSql).Get(&stats)
if err != nil {
return err
}
query.Result = &stats
return err
})
}
func GetSystemUserCountStats(query *m.GetSystemUserCountStatsQuery) error {
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
var stats m.SystemUserCountStats
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment