Commit 25048ebd by Arve Knudsen, committed by GitHub

Chore: Add CloudWatch HTTP API tests (#29691)

* CloudWatch: Add HTTP API tests

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>
parent b6efebd9
......@@ -232,7 +232,7 @@ func (hs *HTTPServer) registerRoutes() {
// orgs (admin routes)
apiRoute.Group("/orgs/name/:name", func(orgsRoute routing.RouteRegister) {
orgsRoute.Get("/", Wrap(GetOrgByName))
orgsRoute.Get("/", Wrap(hs.GetOrgByName))
}, reqGrafanaAdmin)
// auth api keys
......
......@@ -15,6 +15,7 @@ import (
"github.com/grafana/grafana/pkg/services/live"
"github.com/grafana/grafana/pkg/services/search"
"github.com/grafana/grafana/pkg/services/shorturls"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
......@@ -77,6 +78,7 @@ type HTTPServer struct {
ShortURLService *shorturls.ShortURLService `inject:""`
Live *live.GrafanaLive `inject:""`
ContextHandler *contexthandler.ContextHandler `inject:""`
SQLStore *sqlstore.SQLStore `inject:""`
Listener net.Listener
}
......
......@@ -47,7 +47,7 @@ func (hs *HTTPServer) QueryMetricsV2(c *models.ReqContext, reqDto dtos.MetricReq
if i == 0 && !hasExpr {
ds, err = hs.DatasourceCache.GetDatasource(datasourceID, c.SignedInUser, c.SkipCache)
if err != nil {
hs.log.Debug("Encountered error getting data source", "err", err)
hs.log.Debug("Encountered error getting data source", "err", err, "id", datasourceID)
if errors.Is(err, models.ErrDataSourceAccessDenied) {
return Error(403, "Access denied to data source", err)
}
......
......@@ -22,16 +22,15 @@ func GetOrgByID(c *models.ReqContext) Response {
}
// Get /api/orgs/name/:name
func GetOrgByName(c *models.ReqContext) Response {
query := models.GetOrgByNameQuery{Name: c.Params(":name")}
if err := bus.Dispatch(&query); err != nil {
func (hs *HTTPServer) GetOrgByName(c *models.ReqContext) Response {
org, err := hs.SQLStore.GetOrgByName(c.Params(":name"))
if err != nil {
if errors.Is(err, models.ErrOrgNotFound) {
return Error(404, "Organization not found", err)
}
return Error(500, "Failed to get organization", err)
}
org := query.Result
result := models.OrgDetailsDTO{
Id: org.Id,
Name: org.Name,
......
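Below is a minimal, hypothetical sketch (not part of this commit) of exercising the rewritten handler over HTTP. It assumes a running Grafana server at addr, basic-auth Grafana-admin credentials (the route is registered under reqGrafanaAdmin above), and that models.OrgDetailsDTO uses the API's usual lowercase JSON field names.

package example

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
)

// getOrgIDByName is a hypothetical client-side helper (not part of this commit):
// it calls the admin-only endpoint served by the handler above and assumes the
// usual lowercase JSON field names on models.OrgDetailsDTO.
func getOrgIDByName(addr, name, user, pass string) (int64, error) {
	u := fmt.Sprintf("http://%s/api/orgs/name/%s", addr, url.PathEscape(name))
	req, err := http.NewRequest("GET", u, nil)
	if err != nil {
		return 0, err
	}
	req.SetBasicAuth(user, pass)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		// The handler maps models.ErrOrgNotFound to 404 and other errors to 500.
		return 0, fmt.Errorf("unexpected status %d", resp.StatusCode)
	}

	var dto struct {
		ID   int64  `json:"id"`
		Name string `json:"name"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&dto); err != nil {
		return 0, err
	}
	return dto.ID, nil
}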
......@@ -43,6 +43,9 @@ type RouteRegister interface {
// Register iterates over all routes added to the RouteRegister
// and adds them to the `Router` passed as a parameter.
Register(Router)
// Reset resets the route register.
Reset()
}
type RegisterNamedMiddleware func(name string) macaron.Handler
......@@ -71,6 +74,16 @@ type routeRegister struct {
groups []*routeRegister
}
func (rr *routeRegister) Reset() {
if rr == nil {
return
}
rr.routes = nil
rr.groups = nil
rr.subfixHandlers = nil
}
func (rr *routeRegister) Insert(pattern string, fn func(RouteRegister), handlers ...macaron.Handler) {
// loop over all groups at current level
for _, g := range rr.groups {
......
package fs
import (
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
)
// CopyFile copies a file from src to dst.
//
// If src and dst files exist, and are the same, then return success. Otherwise, attempt to create a hard link
// between the two files. If that fails, copy the file contents from src to dst.
func CopyFile(src, dst string) (err error) {
absSrc, err := filepath.Abs(src)
if err != nil {
return fmt.Errorf("failed to get absolute path of source file %q: %w", src, err)
}
sfi, err := os.Stat(src)
if err != nil {
err = fmt.Errorf("couldn't stat source file %q: %w", absSrc, err)
return
}
if !sfi.Mode().IsRegular() {
// Cannot copy non-regular files (e.g., directories, symlinks, devices, etc.)
return fmt.Errorf("non-regular source file %s (%q)", absSrc, sfi.Mode().String())
}
dpath := filepath.Dir(dst)
exists, err := Exists(dpath)
if err != nil {
return err
}
if !exists {
err = fmt.Errorf("destination directory doesn't exist: %q", dpath)
return
}
var dfi os.FileInfo
dfi, err = os.Stat(dst)
if err != nil {
if !os.IsNotExist(err) {
return
}
} else {
if !(dfi.Mode().IsRegular()) {
return fmt.Errorf("non-regular destination file %s (%q)", dfi.Name(), dfi.Mode().String())
}
if os.SameFile(sfi, dfi) {
return copyPermissions(sfi.Name(), dfi.Name())
}
}
if err = os.Link(src, dst); err == nil {
return copyPermissions(src, dst)
}
err = copyFileContents(src, dst)
return err
}
// copyFileContents copies the contents of the file named src to the file named
// by dst. The file will be created if it does not already exist. If the
// destination file exists, all its contents will be replaced by the contents
// of the source file.
func copyFileContents(src, dst string) (err error) {
// Can ignore gosec G304 here, since it's a general file copying function
// nolint:gosec
in, err := os.Open(src)
if err != nil {
return
}
defer func() {
if e := in.Close(); err == nil && e != nil {
err = e
}
}()
out, err := os.Create(dst)
if err != nil {
return
}
defer func() {
if cerr := out.Close(); cerr != nil && err == nil {
err = cerr
}
}()
if _, err = io.Copy(out, in); err != nil {
return
}
if err := out.Sync(); err != nil {
return err
}
return copyPermissions(src, dst)
}
func copyPermissions(src, dst string) error {
sfi, err := os.Lstat(src)
if err != nil {
return err
}
if err := os.Chmod(dst, sfi.Mode()); err != nil {
return err
}
return nil
}
// CopyRecursive copies files and directories recursively.
func CopyRecursive(src, dst string) error {
sfi, err := os.Stat(src)
if err != nil {
return err
}
if !sfi.IsDir() {
return CopyFile(src, dst)
}
if _, err := os.Stat(dst); os.IsNotExist(err) {
if err := os.MkdirAll(dst, sfi.Mode()); err != nil {
return fmt.Errorf("failed to create directory %q: %s", dst, err)
}
}
entries, err := ioutil.ReadDir(src)
if err != nil {
return err
}
for _, entry := range entries {
srcPath := filepath.Join(src, entry.Name())
dstPath := filepath.Join(dst, entry.Name())
srcFi, err := os.Stat(srcPath)
if err != nil {
return err
}
switch srcFi.Mode() & os.ModeType {
case os.ModeDir:
if err := CopyRecursive(srcPath, dstPath); err != nil {
return err
}
case os.ModeSymlink:
link, err := os.Readlink(srcPath)
if err != nil {
return err
}
if err := os.Symlink(link, dstPath); err != nil {
return err
}
default:
if err := CopyFile(srcPath, dstPath); err != nil {
return err
}
}
if srcFi.Mode()&os.ModeSymlink != 0 {
if err := os.Chmod(dstPath, srcFi.Mode()); err != nil {
return err
}
}
}
return nil
}
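A short usage sketch of the new fs helpers, under the assumption that the destination directory for CopyFile already exists (the function errors otherwise), while CopyRecursive creates missing destination directories itself. Illustrative only; not part of the package.

package main

import (
	"io/ioutil"
	"log"
	"os"
	"path/filepath"

	"github.com/grafana/grafana/pkg/infra/fs"
)

func main() {
	dir, err := ioutil.TempDir("", "fs-example")
	if err != nil {
		log.Fatal(err)
	}
	defer os.RemoveAll(dir)

	src := filepath.Join(dir, "src.txt")
	if err := ioutil.WriteFile(src, []byte("hello"), 0600); err != nil {
		log.Fatal(err)
	}

	// The destination's parent directory must already exist; CopyFile then
	// tries a hard link first and falls back to copying the contents.
	if err := fs.CopyFile(src, filepath.Join(dir, "dst.txt")); err != nil {
		log.Fatal(err)
	}

	// CopyRecursive mirrors a whole tree and creates the destination if missing.
	if err := os.MkdirAll(filepath.Join(dir, "tree", "sub"), 0750); err != nil {
		log.Fatal(err)
	}
	if err := fs.CopyRecursive(filepath.Join(dir, "tree"), filepath.Join(dir, "tree-copy")); err != nil {
		log.Fatal(err)
	}
}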
package fs
import (
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCopyFile(t *testing.T) {
src, err := ioutil.TempFile("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(src.Name())
assert.NoError(t, err)
})
err = ioutil.WriteFile(src.Name(), []byte("Contents"), 0600)
require.NoError(t, err)
dst, err := ioutil.TempFile("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(dst.Name())
assert.NoError(t, err)
})
err = CopyFile(src.Name(), dst.Name())
require.NoError(t, err)
}
func TestCopyFile_Permissions(t *testing.T) {
const perms = os.FileMode(0700)
src, err := ioutil.TempFile("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(src.Name())
assert.NoError(t, err)
})
err = ioutil.WriteFile(src.Name(), []byte("Contents"), 0600)
require.NoError(t, err)
err = os.Chmod(src.Name(), perms)
require.NoError(t, err)
dst, err := ioutil.TempFile("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(dst.Name())
assert.NoError(t, err)
})
err = CopyFile(src.Name(), dst.Name())
require.NoError(t, err)
fi, err := os.Stat(dst.Name())
require.NoError(t, err)
assert.Equal(t, perms, fi.Mode()&os.ModePerm)
}
// Test case where destination directory doesn't exist.
func TestCopyFile_NonExistentDestDir(t *testing.T) {
// nolint:gosec
src, err := ioutil.TempFile("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(src.Name())
assert.NoError(t, err)
})
err = CopyFile(src.Name(), "non-existent/dest")
require.EqualError(t, err, "destination directory doesn't exist: \"non-existent\"")
}
func TestCopyRecursive_NonExistentDest(t *testing.T) {
src, err := ioutil.TempDir("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(src)
assert.NoError(t, err)
})
err = os.MkdirAll(filepath.Join(src, "data"), 0750)
require.NoError(t, err)
// nolint:gosec
err = ioutil.WriteFile(filepath.Join(src, "data", "file.txt"), []byte("Test"), 0644)
require.NoError(t, err)
dstParent, err := ioutil.TempDir("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(dstParent)
assert.NoError(t, err)
})
dst := filepath.Join(dstParent, "dest")
err = CopyRecursive(src, dst)
require.NoError(t, err)
compareDirs(t, src, dst)
}
func TestCopyRecursive_ExistentDest(t *testing.T) {
src, err := ioutil.TempDir("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(src)
assert.NoError(t, err)
})
err = os.MkdirAll(filepath.Join(src, "data"), 0750)
require.NoError(t, err)
// nolint:gosec
err = ioutil.WriteFile(filepath.Join(src, "data", "file.txt"), []byte("Test"), 0644)
require.NoError(t, err)
dst, err := ioutil.TempDir("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(dst)
assert.NoError(t, err)
})
err = CopyRecursive(src, dst)
require.NoError(t, err)
compareDirs(t, src, dst)
}
func compareDirs(t *testing.T, src, dst string) {
sfi, err := os.Stat(src)
require.NoError(t, err)
dfi, err := os.Stat(dst)
require.NoError(t, err)
require.Equal(t, sfi.Mode(), dfi.Mode())
err = filepath.Walk(src, func(srcPath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath := strings.TrimPrefix(srcPath, src)
dstPath := filepath.Join(dst, relPath)
sfi, err := os.Stat(srcPath)
require.NoError(t, err)
dfi, err := os.Stat(dstPath)
require.NoError(t, err)
require.Equal(t, sfi.Mode(), dfi.Mode())
if sfi.IsDir() {
return nil
}
// nolint:gosec
srcData, err := ioutil.ReadFile(srcPath)
require.NoError(t, err)
// nolint:gosec
dstData, err := ioutil.ReadFile(dstPath)
require.NoError(t, err)
require.Equal(t, srcData, dstData)
return nil
})
require.NoError(t, err)
}
......@@ -57,14 +57,12 @@ func TestMiddlewareAuth(t *testing.T) {
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(
t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
org, err := sc.sqlStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1)
require.NoError(t, err)
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec()
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", org.Id)).exec()
assert.Equal(t, 200, sc.resp.Code)
}, configure)
......
......@@ -329,19 +329,13 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 2
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
assert.Equal(t, "test", query.Name)
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
org, err := sc.sqlStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1)
require.NoError(t, err)
sc.fakeReq("GET", "/").exec()
assert.Equal(t, int64(0), sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, org.Id, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
assert.False(t, sc.context.IsSignedIn)
}, func(cfg *setting.Cfg) {
......@@ -572,6 +566,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
}))
ctxHdlr := getContextHandler(t, cfg)
sc.sqlStore = ctxHdlr.SQLStore
sc.contextHandler = ctxHdlr
sc.m.Use(ctxHdlr.Middleware)
sc.m.Use(OrgRedirect(sc.cfg))
......
......@@ -12,6 +12,7 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
......@@ -31,6 +32,7 @@ type scenarioContext struct {
userAuthTokenService *auth.FakeUserAuthTokenService
remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg
sqlStore *sqlstore.SQLStore
contextHandler *contexthandler.ContextHandler
req *http.Request
......
......@@ -47,6 +47,18 @@ func Register(descriptor *Descriptor) {
services = append(services, descriptor)
}
// GetService gets the registered service descriptor with a certain name.
// If none is found, nil is returned.
func GetService(name string) *Descriptor {
for _, svc := range services {
if svc.Name == name {
return svc
}
}
return nil
}
func GetServices() []*Descriptor {
slice := getServicesWithOverrides()
......
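GetService lets tests look up a registered descriptor, swap in their own instance, and restore the original afterwards; the CloudWatch HTTP API tests below use exactly this pattern to inject a pre-initialized SQLStore. A condensed sketch of that pattern, written as a hypothetical test helper:

package example

import (
	"testing"

	"github.com/grafana/grafana/pkg/registry"
	"github.com/grafana/grafana/pkg/services/sqlstore"
)

// replaceSQLStore is a hypothetical test helper showing the swap-and-restore
// pattern: fetch the registered descriptor, register a replacement, and put
// the original back when the test finishes.
func replaceSQLStore(t *testing.T, store *sqlstore.SQLStore) {
	t.Helper()

	orig := registry.GetService(sqlstore.ServiceName)
	t.Cleanup(func() {
		registry.Register(orig)
	})

	registry.Register(&registry.Descriptor{
		Name:         sqlstore.ServiceName,
		Instance:     store,
		InitPriority: sqlstore.InitPriority,
	})
}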
......@@ -13,7 +13,6 @@ import (
"sync"
"time"
"github.com/facebookgo/inject"
"golang.org/x/sync/errgroup"
"github.com/grafana/grafana/pkg/api"
......@@ -43,7 +42,6 @@ import (
_ "github.com/grafana/grafana/pkg/services/search"
_ "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
)
// Config contains parameters for the New function.
......@@ -75,11 +73,11 @@ func New(cfg Config) (*Server, error) {
version: cfg.Version,
commit: cfg.Commit,
buildBranch: cfg.BuildBranch,
listener: cfg.Listener,
}
if cfg.Listener != nil {
if err := s.init(&cfg); err != nil {
return nil, err
}
if err := s.init(); err != nil {
return nil, err
}
return s, nil
......@@ -96,6 +94,7 @@ type Server struct {
shutdownInProgress bool
isInitialized bool
mtx sync.Mutex
listener net.Listener
configFile string
homePath string
......@@ -108,7 +107,7 @@ type Server struct {
}
// init initializes the server and its services.
func (s *Server) init(cfg *Config) error {
func (s *Server) init() error {
s.mtx.Lock()
defer s.mtx.Unlock()
......@@ -131,25 +130,15 @@ func (s *Server) init(cfg *Config) error {
return err
}
// Initialize services.
for _, service := range services {
if registry.IsDisabled(service.Instance) {
continue
}
if cfg != nil {
if s.listener != nil {
for _, service := range services {
if httpS, ok := service.Instance.(*api.HTTPServer); ok {
// Configure the api.HTTPServer if necessary
// Hopefully we can find a better solution, maybe with a more advanced DI framework, e.g. Dig?
if cfg.Listener != nil {
s.log.Debug("Using provided listener for HTTP server")
httpS.Listener = cfg.Listener
}
s.log.Debug("Using provided listener for HTTP server")
httpS.Listener = s.listener
}
}
if err := service.Instance.Init(); err != nil {
return errutil.Wrapf(err, "Service init failed")
}
}
return nil
......@@ -158,7 +147,7 @@ func (s *Server) init(cfg *Config) error {
// Run initializes and starts services. This will block until all services have
// exited. To initiate shutdown, call the Shutdown method in another goroutine.
func (s *Server) Run() (err error) {
if err = s.init(nil); err != nil {
if err = s.init(); err != nil {
return
}
......@@ -278,26 +267,7 @@ func (s *Server) buildServiceGraph(services []*registry.Descriptor) error {
localcache.New(5*time.Minute, 10*time.Minute),
s,
}
for _, service := range services {
objs = append(objs, service.Instance)
}
var serviceGraph inject.Graph
// Provide services and their dependencies to the graph.
for _, obj := range objs {
if err := serviceGraph.Provide(&inject.Object{Value: obj}); err != nil {
return errutil.Wrapf(err, "Failed to provide object to the graph")
}
}
// Resolve services and their dependencies.
if err := serviceGraph.Populate(); err != nil {
return errutil.Wrapf(err, "Failed to populate service dependencies")
}
return nil
return registry.BuildServiceGraph(objs, services)
}
// loadConfiguration loads settings and configuration from config files.
......
......@@ -8,12 +8,11 @@ import (
)
func (s *UserAuthTokenService) Run(ctx context.Context) error {
var err error
ticker := time.NewTicker(time.Hour)
maxInactiveLifetime := s.Cfg.LoginMaxInactiveLifetime
maxLifetime := s.Cfg.LoginMaxLifetime
err = s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() {
err := s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() {
if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil {
s.log.Error("An error occurred while deleting expired tokens", "err", err)
}
......
......@@ -114,8 +114,8 @@ func (h *ContextHandler) initContextWithAnonymousUser(ctx *models.ReqContext) bo
return false
}
orgQuery := models.GetOrgByNameQuery{Name: h.Cfg.AnonymousOrgName}
if err := bus.Dispatch(&orgQuery); err != nil {
org, err := h.SQLStore.GetOrgByName(h.Cfg.AnonymousOrgName)
if err != nil {
log.Errorf(3, "Anonymous access organization error: '%s': %s", h.Cfg.AnonymousOrgName, err)
return false
}
......@@ -124,8 +124,8 @@ func (h *ContextHandler) initContextWithAnonymousUser(ctx *models.ReqContext) bo
ctx.AllowAnonymous = true
ctx.SignedInUser = &models.SignedInUser{IsAnonymous: true}
ctx.OrgRole = models.RoleType(h.Cfg.AnonymousOrgRole)
ctx.OrgId = orgQuery.Result.Id
ctx.OrgName = orgQuery.Result.Name
ctx.OrgId = org.Id
ctx.OrgName = org.Name
return true
}
......
......@@ -6,6 +6,7 @@ import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log"
......@@ -28,7 +29,8 @@ func TestDashboardsAsConfig(t *testing.T) {
t.Run("Should fail if orgs don't exist in the database", func(t *testing.T) {
cfgProvider := configReader{path: appliedDefaults, log: logger}
_, err := cfgProvider.readConfig()
require.Equal(t, errors.Unwrap(err), models.ErrOrgNotFound)
require.Error(t, err)
assert.True(t, errors.Is(err, models.ErrOrgNotFound))
})
for i := 1; i <= 2; i++ {
......
package migrator
import (
"errors"
"time"
_ "github.com/go-sql-driver/mysql"
......@@ -115,7 +114,8 @@ func (mg *Migrator) Start() error {
}
}
return nil
// Make sure migrations are synced
return mg.x.Sync2()
}
func (mg *Migrator) exec(m Migration, sess *xorm.Session) error {
......@@ -126,7 +126,7 @@ func (mg *Migrator) exec(m Migration, sess *xorm.Session) error {
sql, args := condition.SQL(mg.Dialect)
if sql != "" {
mg.Logger.Debug("Executing migration condition sql", "id", m.Id(), "sql", sql, "args", args)
mg.Logger.Debug("Executing migration condition SQL", "id", m.Id(), "sql", sql, "args", args)
results, err := sess.SQL(sql, args...).Query()
if err != nil {
mg.Logger.Error("Executing migration condition failed", "id", m.Id(), "error", err)
......@@ -169,8 +169,8 @@ func (mg *Migrator) inTransaction(callback dbTransactionFunc) error {
}
if err := callback(sess); err != nil {
if rollErr := sess.Rollback(); !errors.Is(err, rollErr) {
return errutil.Wrapf(err, "failed to roll back transaction due to error: %s; initial err: %s", rollErr, err)
if rollErr := sess.Rollback(); rollErr != nil {
return errutil.Wrapf(err, "failed to roll back transaction due to error: %s", rollErr)
}
return err
......
......@@ -125,14 +125,18 @@ func (db *MySQLDialect) CleanDB() error {
defer sess.Close()
for _, table := range tables {
if _, err := sess.Exec("set foreign_key_checks = 0"); err != nil {
return errutil.Wrap("failed to disable foreign key checks", err)
}
if _, err := sess.Exec("drop table " + table.Name + " ;"); err != nil {
return errutil.Wrapf(err, "failed to delete table %q", table.Name)
}
if _, err := sess.Exec("set foreign_key_checks = 1"); err != nil {
return errutil.Wrap("failed to re-enable foreign key checks", err)
switch table.Name {
case "migration_log":
default:
if _, err := sess.Exec("set foreign_key_checks = 0"); err != nil {
return errutil.Wrap("failed to disable foreign key checks", err)
}
if _, err := sess.Exec("drop table " + table.Name + " ;"); err != nil {
return errutil.Wrapf(err, "failed to delete table %q", table.Name)
}
if _, err := sess.Exec("set foreign_key_checks = 1"); err != nil {
return errutil.Wrap("failed to re-enable foreign key checks", err)
}
}
}
......
......@@ -151,6 +151,7 @@ func (db *PostgresDialect) TruncateDBTables() error {
switch table.Name {
case "":
continue
case "migration_log":
case "dashboard_acl":
// keep default dashboard permissions
if _, err := sess.Exec(fmt.Sprintf("DELETE FROM %v WHERE dashboard_id != -1 AND org_id != -1;", db.Quote(table.Name))); err != nil {
......
......@@ -100,6 +100,7 @@ func (db *SQLite3) TruncateDBTables() error {
for _, table := range tables {
switch table.Name {
case "migration_log":
case "dashboard_acl":
// keep default dashboard permissions
if _, err := sess.Exec(fmt.Sprintf("DELETE FROM %q WHERE dashboard_id != -1 AND org_id != -1;", table.Name)); err != nil {
......
package sqlstore
import (
"context"
"fmt"
"time"
......@@ -8,9 +9,11 @@ import (
"github.com/grafana/grafana/pkg/events"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"xorm.io/xorm"
)
const mainOrgName = "Main Org."
// MainOrgName is the name of the main organization.
const MainOrgName = "Main Org."
func init() {
bus.AddHandler("sql", GetOrgById)
......@@ -72,6 +75,20 @@ func GetOrgByName(query *models.GetOrgByNameQuery) error {
return nil
}
// GetOrgByName gets an organization by name.
func (ss *SQLStore) GetOrgByName(name string) (*models.Org, error) {
var org models.Org
exists, err := ss.engine.Where("name=?", name).Get(&org)
if err != nil {
return nil, err
}
if !exists {
return nil, models.ErrOrgNotFound
}
return &org, nil
}
func isOrgNameTaken(name string, existingId int64, sess *DBSession) (bool, error) {
// check if org name is taken
var org models.Org
......@@ -88,34 +105,32 @@ func isOrgNameTaken(name string, existingId int64, sess *DBSession) (bool, error
return false, nil
}
func CreateOrg(cmd *models.CreateOrgCommand) error {
return inTransaction(func(sess *DBSession) error {
if isNameTaken, err := isOrgNameTaken(cmd.Name, 0, sess); err != nil {
func createOrg(name string, userID int64, engine *xorm.Engine) (models.Org, error) {
org := models.Org{
Name: name,
Created: time.Now(),
Updated: time.Now(),
}
if err := inTransactionWithRetryCtx(context.Background(), engine, func(sess *DBSession) error {
if isNameTaken, err := isOrgNameTaken(name, 0, sess); err != nil {
return err
} else if isNameTaken {
return models.ErrOrgNameTaken
}
org := models.Org{
Name: cmd.Name,
Created: time.Now(),
Updated: time.Now(),
}
if _, err := sess.Insert(&org); err != nil {
return err
}
user := models.OrgUser{
OrgId: org.Id,
UserId: cmd.UserId,
UserId: userID,
Role: models.ROLE_ADMIN,
Created: time.Now(),
Updated: time.Now(),
}
_, err := sess.Insert(&user)
cmd.Result = org
sess.publishAfterCommit(&events.OrgCreated{
Timestamp: org.Created,
......@@ -124,7 +139,26 @@ func CreateOrg(cmd *models.CreateOrgCommand) error {
})
return err
})
}, 0); err != nil {
return org, err
}
return org, nil
}
// CreateOrgWithMember creates an organization with the given name and the given user as a member.
func (ss *SQLStore) CreateOrgWithMember(name string, userID int64) (models.Org, error) {
return createOrg(name, userID, ss.engine)
}
func CreateOrg(cmd *models.CreateOrgCommand) error {
org, err := createOrg(cmd.Name, cmd.UserId, x)
if err != nil {
return err
}
cmd.Result = org
return nil
}
func UpdateOrg(cmd *models.UpdateOrgCommand) error {
......@@ -229,6 +263,52 @@ func verifyExistingOrg(sess *DBSession, orgId int64) error {
return nil
}
func (ss *SQLStore) getOrCreateOrg(sess *DBSession, orgName string) (int64, error) {
var org models.Org
if ss.Cfg.AutoAssignOrg {
has, err := sess.Where("id=?", ss.Cfg.AutoAssignOrgId).Get(&org)
if err != nil {
return 0, err
}
if has {
return org.Id, nil
}
if ss.Cfg.AutoAssignOrgId != 1 {
ss.log.Error("Could not create user: organization ID does not exist", "orgID",
ss.Cfg.AutoAssignOrgId)
return 0, fmt.Errorf("could not create user: organization ID %d does not exist",
ss.Cfg.AutoAssignOrgId)
}
org.Name = MainOrgName
org.Id = int64(ss.Cfg.AutoAssignOrgId)
} else {
org.Name = orgName
}
org.Created = time.Now()
org.Updated = time.Now()
if org.Id != 0 {
if _, err := sess.InsertId(&org); err != nil {
return 0, err
}
} else {
if _, err := sess.InsertOne(&org); err != nil {
return 0, err
}
}
sess.publishAfterCommit(&events.OrgCreated{
Timestamp: org.Created,
Id: org.Id,
Name: org.Name,
})
return org.Id, nil
}
func getOrCreateOrg(sess *DBSession, orgName string) (int64, error) {
var org models.Org
if setting.AutoAssignOrg {
......@@ -247,7 +327,7 @@ func getOrCreateOrg(sess *DBSession, orgName string) (int64, error) {
setting.AutoAssignOrgId)
}
org.Name = mainOrgName
org.Name = MainOrgName
org.Id = int64(setting.AutoAssignOrgId)
} else {
org.Name = orgName
......
......@@ -44,17 +44,7 @@ const InitPriority = registry.High
func init() {
ss := &SQLStore{}
// This change makes xorm use an empty default schema for Postgres, thereby
// mimicking how it functioned before xorm's changes above.
xorm.DefaultPostgresSchema = ""
registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: ss,
InitPriority: InitPriority,
})
ss.Register()
}
type SQLStore struct {
......@@ -69,23 +59,35 @@ type SQLStore struct {
skipEnsureDefaultOrgAndUser bool
}
// Register registers the SQLStore service with the DI system.
func (ss *SQLStore) Register() {
// This change makes xorm use an empty default schema for Postgres, thereby
// mimicking how it functioned before xorm's changes above.
xorm.DefaultPostgresSchema = ""
registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: ss,
InitPriority: InitPriority,
})
}
func (ss *SQLStore) Init() error {
ss.log = log.New("sqlstore")
ss.readConfig()
engine, err := ss.getEngine()
if err != nil {
if err := ss.initEngine(); err != nil {
return errutil.Wrap("failed to connect to database", err)
}
ss.engine = engine
ss.Dialect = migrator.NewDialect(ss.engine)
// temporarily still set global var
x = engine
x = ss.engine
dialect = ss.Dialect
migrator := migrator.NewMigrator(engine)
migrator := migrator.NewMigrator(ss.engine)
migrations.AddMigrations(migrator)
for _, descriptor := range registry.GetServices() {
......@@ -96,7 +98,7 @@ func (ss *SQLStore) Init() error {
}
if err := migrator.Start(); err != nil {
return errutil.Wrap("migration failed", err)
return err
}
// Init repo instances
......@@ -109,6 +111,25 @@ func (ss *SQLStore) Init() error {
ss.addAlertNotificationUidByIdHandler()
ss.addPreferencesQueryAndCommandHandlers()
if err := ss.Reset(); err != nil {
return err
}
// Make sure the changes are synced, so they get shared with any other DB connections
if err := ss.Sync(); err != nil {
return err
}
return nil
}
// Sync syncs changes to the database.
func (ss *SQLStore) Sync() error {
return ss.engine.Sync2()
}
// Reset resets database state.
// If default org and user creation is enabled, it will be ensured they exist in the database.
func (ss *SQLStore) Reset() error {
if ss.skipEnsureDefaultOrgAndUser {
return nil
}
......@@ -141,25 +162,27 @@ func (ss *SQLStore) ensureMainOrgAndAdminUser() error {
// ensure admin user
if !ss.Cfg.DisableInitAdminCreation {
ss.log.Debug("Creating default admin user")
cmd := models.CreateUserCommand{
ss.log.Debug("Creating default admin user")
if _, err := ss.createUser(ctx, userCreationArgs{
Login: ss.Cfg.AdminUser,
Email: ss.Cfg.AdminUser + "@localhost",
Password: ss.Cfg.AdminPassword,
IsAdmin: true,
}
if err := bus.DispatchCtx(ctx, &cmd); err != nil {
}, false); err != nil {
return fmt.Errorf("failed to create admin user: %s", err)
}
ss.log.Info("Created default admin", "user", ss.Cfg.AdminUser)
return nil
// Why should we return and not create the default org in this case?
// Returning here breaks tests using anonymous access
// return nil
}
// ensure default org even if default admin user is disabled
if err := inTransactionCtx(ctx, func(sess *DBSession) error {
_, err := getOrCreateOrg(sess, mainOrgName)
if err := inTransactionWithRetryCtx(ctx, ss.engine, func(sess *DBSession) error {
ss.log.Debug("Creating default org", "name", MainOrgName)
_, err := ss.getOrCreateOrg(sess, MainOrgName)
return err
}); err != nil {
}, 0); err != nil {
return fmt.Errorf("failed to create default organization: %w", err)
}
......@@ -253,10 +276,16 @@ func (ss *SQLStore) buildConnectionString() (string, error) {
return cnnstr, nil
}
func (ss *SQLStore) getEngine() (*xorm.Engine, error) {
// initEngine initializes ss.engine.
func (ss *SQLStore) initEngine() error {
if ss.engine != nil {
sqlog.Debug("Already connected to database")
return nil
}
connectionString, err := ss.buildConnectionString()
if err != nil {
return nil, err
return err
}
if ss.Cfg.IsDatabaseMetricsEnabled() {
......@@ -264,10 +293,11 @@ func (ss *SQLStore) getEngine() (*xorm.Engine, error) {
}
sqlog.Info("Connecting to DB", "dbtype", ss.dbCfg.Type)
if ss.dbCfg.Type == migrator.SQLite && strings.HasPrefix(connectionString, "file:") {
if ss.dbCfg.Type == migrator.SQLite && strings.HasPrefix(connectionString, "file:") &&
!strings.HasPrefix(connectionString, "file::memory:") {
exists, err := fs.Exists(ss.dbCfg.Path)
if err != nil {
return nil, errutil.Wrapf(err, "can't check for existence of %q", ss.dbCfg.Path)
return errutil.Wrapf(err, "can't check for existence of %q", ss.dbCfg.Path)
}
const perms = 0640
......@@ -275,15 +305,15 @@ func (ss *SQLStore) getEngine() (*xorm.Engine, error) {
ss.log.Info("Creating SQLite database file", "path", ss.dbCfg.Path)
f, err := os.OpenFile(ss.dbCfg.Path, os.O_CREATE|os.O_RDWR, perms)
if err != nil {
return nil, errutil.Wrapf(err, "failed to create SQLite database file %q", ss.dbCfg.Path)
return errutil.Wrapf(err, "failed to create SQLite database file %q", ss.dbCfg.Path)
}
if err := f.Close(); err != nil {
return nil, errutil.Wrapf(err, "failed to create SQLite database file %q", ss.dbCfg.Path)
return errutil.Wrapf(err, "failed to create SQLite database file %q", ss.dbCfg.Path)
}
} else {
fi, err := os.Lstat(ss.dbCfg.Path)
if err != nil {
return nil, errutil.Wrapf(err, "failed to stat SQLite database file %q", ss.dbCfg.Path)
return errutil.Wrapf(err, "failed to stat SQLite database file %q", ss.dbCfg.Path)
}
m := fi.Mode() & os.ModePerm
if m|perms != perms {
......@@ -294,7 +324,7 @@ func (ss *SQLStore) getEngine() (*xorm.Engine, error) {
}
engine, err := xorm.NewEngine(ss.dbCfg.Type, connectionString)
if err != nil {
return nil, err
return err
}
engine.SetMaxOpenConns(ss.dbCfg.MaxOpenConn)
......@@ -311,9 +341,11 @@ func (ss *SQLStore) getEngine() (*xorm.Engine, error) {
engine.ShowExecTime(true)
}
return engine, nil
ss.engine = engine
return nil
}
// readConfig initializes the SQLStore from its configuration.
func (ss *SQLStore) readConfig() {
sec := ss.Cfg.Raw.Section("database")
......@@ -363,18 +395,29 @@ type ITestDB interface {
Helper()
Fatalf(format string, args ...interface{})
Logf(format string, args ...interface{})
Log(args ...interface{})
}
var testSQLStore *SQLStore
// InitTestDBOpt contains options for InitTestDB.
type InitTestDBOpt struct {
// EnsureDefaultOrgAndUser flags whether to ensure that default org and user exist.
EnsureDefaultOrgAndUser bool
}
// InitTestDB initializes the test DB.
func InitTestDB(t ITestDB) *SQLStore {
func InitTestDB(t ITestDB, opts ...InitTestDBOpt) *SQLStore {
t.Helper()
if testSQLStore == nil {
testSQLStore = &SQLStore{}
testSQLStore.Bus = bus.New()
testSQLStore.CacheService = localcache.New(5*time.Minute, 10*time.Minute)
testSQLStore.skipEnsureDefaultOrgAndUser = false
testSQLStore.skipEnsureDefaultOrgAndUser = true
for _, opt := range opts {
testSQLStore.skipEnsureDefaultOrgAndUser = !opt.EnsureDefaultOrgAndUser
}
dbType := migrator.SQLite
......@@ -423,19 +466,26 @@ func InitTestDB(t ITestDB) *SQLStore {
t.Logf("Cleaning DB")
if err := dialect.CleanDB(); err != nil {
t.Fatalf("Failed to clean test db %v", err)
t.Fatalf("Failed to clean test db: %s", err)
}
if err := testSQLStore.Init(); err != nil {
t.Fatalf("Failed to init test database: %v", err)
t.Fatalf("Failed to init test database: %s", err)
}
t.Log("Successfully initialized test database")
testSQLStore.engine.DatabaseTZ = time.UTC
testSQLStore.engine.TZLocation = time.UTC
return testSQLStore
}
t.Log("Truncating DB tables")
if err := dialect.TruncateDBTables(); err != nil {
t.Fatalf("Failed to truncate test db %v", err)
t.Fatalf("Failed to truncate test db: %s", err)
}
if err := testSQLStore.Reset(); err != nil {
t.Fatalf("Failed to reset SQLStore: %s", err)
}
return testSQLStore
......
......@@ -14,7 +14,8 @@ func SQLite3TestDB() TestDB {
// To run all tests in a local test database, set ConnStr to "grafana_test.db"
return TestDB{
DriverName: "sqlite3",
ConnStr: ":memory:",
// ConnStr specifies an in-memory database shared between connections.
ConnStr: "file::memory:?cache=shared",
}
}
......
......@@ -29,7 +29,7 @@ func TestIntegration_GetUserStats(t *testing.T) {
Email: "admin@test.com",
Name: "Admin",
Login: "admin",
OrgName: mainOrgName,
OrgName: MainOrgName,
IsAdmin: true,
}
err := CreateUser(context.Background(), cmd)
......
......@@ -57,6 +57,140 @@ func getOrgIdForNewUser(sess *DBSession, cmd *models.CreateUserCommand) (int64,
return getOrCreateOrg(sess, orgName)
}
type userCreationArgs struct {
Login string
Email string
Name string
Company string
Password string
IsAdmin bool
IsDisabled bool
EmailVerified bool
OrgID int64
OrgName string
DefaultOrgRole string
}
func (ss *SQLStore) getOrgIDForNewUser(sess *DBSession, args userCreationArgs) (int64, error) {
if ss.Cfg.AutoAssignOrg && args.OrgID != 0 {
if err := verifyExistingOrg(sess, args.OrgID); err != nil {
return -1, err
}
return args.OrgID, nil
}
orgName := args.OrgName
if orgName == "" {
orgName = util.StringsFallback2(args.Email, args.Login)
}
return ss.getOrCreateOrg(sess, orgName)
}
// createUser creates a user in the database.
func (ss *SQLStore) createUser(ctx context.Context, args userCreationArgs, skipOrgSetup bool) (models.User, error) {
var user models.User
if err := inTransactionWithRetryCtx(ctx, ss.engine, func(sess *DBSession) error {
var orgID int64 = -1
if !skipOrgSetup {
var err error
orgID, err = ss.getOrgIDForNewUser(sess, args)
if err != nil {
return err
}
}
if args.Email == "" {
args.Email = args.Login
}
exists, err := sess.Where("email=? OR login=?", args.Email, args.Login).Get(&models.User{})
if err != nil {
return err
}
if exists {
return models.ErrUserAlreadyExists
}
// create user
user = models.User{
Email: args.Email,
Name: args.Name,
Login: args.Login,
Company: args.Company,
IsAdmin: args.IsAdmin,
IsDisabled: args.IsDisabled,
OrgId: orgID,
EmailVerified: args.EmailVerified,
Created: time.Now(),
Updated: time.Now(),
LastSeenAt: time.Now().AddDate(-10, 0, 0),
}
salt, err := util.GetRandomString(10)
if err != nil {
return err
}
user.Salt = salt
rands, err := util.GetRandomString(10)
if err != nil {
return err
}
user.Rands = rands
if len(args.Password) > 0 {
encodedPassword, err := util.EncodePassword(args.Password, user.Salt)
if err != nil {
return err
}
user.Password = encodedPassword
}
sess.UseBool("is_admin")
if _, err := sess.Insert(&user); err != nil {
return err
}
sess.publishAfterCommit(&events.UserCreated{
Timestamp: user.Created,
Id: user.Id,
Name: user.Name,
Login: user.Login,
Email: user.Email,
})
// create org user link
if !skipOrgSetup {
orgUser := models.OrgUser{
OrgId: orgID,
UserId: user.Id,
Role: models.ROLE_ADMIN,
Created: time.Now(),
Updated: time.Now(),
}
if ss.Cfg.AutoAssignOrg && !user.IsAdmin {
if len(args.DefaultOrgRole) > 0 {
orgUser.Role = models.RoleType(args.DefaultOrgRole)
} else {
orgUser.Role = models.RoleType(ss.Cfg.AutoAssignOrgRole)
}
}
if _, err = sess.Insert(&orgUser); err != nil {
return err
}
}
return nil
}, 0); err != nil {
return user, err
}
return user, nil
}
func CreateUser(ctx context.Context, cmd *models.CreateUserCommand) error {
return inTransactionCtx(ctx, func(sess *DBSession) error {
orgId, err := getOrgIdForNewUser(sess, cmd)
......
......@@ -331,6 +331,10 @@ type Cfg struct {
Quota QuotaSettings
DefaultTheme string
AutoAssignOrg bool
AutoAssignOrgId int
AutoAssignOrgRole string
}
// IsExpressionsEnabled returns whether the expressions feature is enabled.
......@@ -1125,9 +1129,12 @@ func readUserSettings(iniFile *ini.File, cfg *Cfg) error {
users := iniFile.Section("users")
AllowUserSignUp = users.Key("allow_sign_up").MustBool(true)
AllowUserOrgCreate = users.Key("allow_org_create").MustBool(true)
AutoAssignOrg = users.Key("auto_assign_org").MustBool(true)
AutoAssignOrgId = users.Key("auto_assign_org_id").MustInt(1)
AutoAssignOrgRole = users.Key("auto_assign_org_role").In("Editor", []string{"Editor", "Admin", "Viewer"})
cfg.AutoAssignOrg = users.Key("auto_assign_org").MustBool(true)
AutoAssignOrg = cfg.AutoAssignOrg
cfg.AutoAssignOrgId = users.Key("auto_assign_org_id").MustInt(1)
AutoAssignOrgId = cfg.AutoAssignOrgId
cfg.AutoAssignOrgRole = users.Key("auto_assign_org_role").In("Editor", []string{"Editor", "Admin", "Viewer"})
AutoAssignOrgRole = cfg.AutoAssignOrgRole
VerifyEmailEnabled = users.Key("verify_email_enabled").MustBool(false)
LoginHint = valueAsString(users, "login_hint", "")
......
# Integration tests
This directory contains Grafana server integration tests.
# API integration tests
This directory contains Grafana HTTP API integration tests.
package metrics
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/server"
"github.com/grafana/grafana/pkg/tsdb"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch"
cwapi "github.com/aws/aws-sdk-go/service/cloudwatch"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/fs"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/ini.v1"
)
func TestQueryCloudWatchMetrics(t *testing.T) {
grafDir, cfgPath := createGrafDir(t)
sqlStore := setUpDatabase(t, grafDir)
addr := startGrafana(t, grafDir, cfgPath, sqlStore)
origNewCWClient := cloudwatch.NewCWClient
t.Cleanup(func() {
cloudwatch.NewCWClient = origNewCWClient
})
var client cloudwatch.FakeCWClient
cloudwatch.NewCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
return client
}
t.Run("Custom metrics", func(t *testing.T) {
client = cloudwatch.FakeCWClient{
Metrics: []*cwapi.Metric{
{
MetricName: aws.String("Test_MetricName"),
Dimensions: []*cwapi.Dimension{
{
Name: aws.String("Test_DimensionName"),
},
},
},
},
}
req := dtos.MetricRequest{
Queries: []*simplejson.Json{
simplejson.NewFromAny(map[string]interface{}{
"type": "metricFindQuery",
"subtype": "metrics",
"region": "us-east-1",
"namespace": "custom",
"datasourceId": 1,
}),
},
}
tr := makeCWRequest(t, req, addr)
assert.Equal(t, tsdb.Response{
Results: map[string]*tsdb.QueryResult{
"A": {
RefId: "A",
Meta: simplejson.NewFromAny(map[string]interface{}{
"rowCount": float64(1),
}),
Tables: []*tsdb.Table{
{
Columns: []tsdb.TableColumn{
{
Text: "text",
},
{
Text: "value",
},
},
Rows: []tsdb.RowValues{
{
"Test_MetricName",
"Test_MetricName",
},
},
},
},
},
},
}, tr)
})
}
func TestQueryCloudWatchLogs(t *testing.T) {
grafDir, cfgPath := createGrafDir(t)
sqlStore := setUpDatabase(t, grafDir)
addr := startGrafana(t, grafDir, cfgPath, sqlStore)
origNewCWLogsClient := cloudwatch.NewCWLogsClient
t.Cleanup(func() {
cloudwatch.NewCWLogsClient = origNewCWLogsClient
})
var client cloudwatch.FakeCWLogsClient
cloudwatch.NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
return client
}
t.Run("Describe log groups", func(t *testing.T) {
client = cloudwatch.FakeCWLogsClient{}
req := dtos.MetricRequest{
Queries: []*simplejson.Json{
simplejson.NewFromAny(map[string]interface{}{
"type": "logAction",
"subtype": "DescribeLogGroups",
"region": "us-east-1",
"datasourceId": 1,
}),
},
}
tr := makeCWRequest(t, req, addr)
dataFrames := tsdb.NewDecodedDataFrames(data.Frames{
&data.Frame{
Name: "logGroups",
Fields: []*data.Field{
data.NewField("logGroupName", nil, []*string{}),
},
Meta: &data.FrameMeta{
PreferredVisualization: "logs",
},
},
})
// Have to call this so that dataFrames.encoded is non-nil, for the comparison
// In the future we should use gocmp instead and ignore this field
_, err := dataFrames.Encoded()
require.NoError(t, err)
assert.Equal(t, tsdb.Response{
Results: map[string]*tsdb.QueryResult{
"A": {
RefId: "A",
Dataframes: dataFrames,
},
},
}, tr)
})
}
func makeCWRequest(t *testing.T, req dtos.MetricRequest, addr string) tsdb.Response {
t.Helper()
buf := bytes.Buffer{}
enc := json.NewEncoder(&buf)
err := enc.Encode(&req)
require.NoError(t, err)
u := fmt.Sprintf("http://%s/api/ds/query", addr)
t.Logf("Making POST request to %s", u)
// nolint:gosec
resp, err := http.Post(u, "application/json", &buf)
require.NoError(t, err)
require.NotNil(t, resp)
t.Cleanup(func() {
err := resp.Body.Close()
assert.NoError(t, err)
})
buf = bytes.Buffer{}
_, err = io.Copy(&buf, resp.Body)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
var tr tsdb.Response
err = json.Unmarshal(buf.Bytes(), &tr)
require.NoError(t, err)
return tr
}
func createGrafDir(t *testing.T) (string, string) {
t.Helper()
tmpDir, err := ioutil.TempDir("", "")
require.NoError(t, err)
t.Cleanup(func() {
err := os.RemoveAll(tmpDir)
assert.NoError(t, err)
})
rootDir := filepath.Join("..", "..", "..", "..")
cfgDir := filepath.Join(tmpDir, "conf")
err = os.MkdirAll(cfgDir, 0750)
require.NoError(t, err)
dataDir := filepath.Join(tmpDir, "data")
// nolint:gosec
err = os.MkdirAll(dataDir, 0750)
require.NoError(t, err)
logsDir := filepath.Join(tmpDir, "logs")
pluginsDir := filepath.Join(tmpDir, "plugins")
publicDir := filepath.Join(tmpDir, "public")
err = os.MkdirAll(publicDir, 0750)
require.NoError(t, err)
emailsDir := filepath.Join(publicDir, "emails")
err = fs.CopyRecursive(filepath.Join(rootDir, "public", "emails"), emailsDir)
require.NoError(t, err)
provDir := filepath.Join(cfgDir, "provisioning")
provDSDir := filepath.Join(provDir, "datasources")
err = os.MkdirAll(provDSDir, 0750)
require.NoError(t, err)
provNotifiersDir := filepath.Join(provDir, "notifiers")
err = os.MkdirAll(provNotifiersDir, 0750)
require.NoError(t, err)
provPluginsDir := filepath.Join(provDir, "plugins")
err = os.MkdirAll(provPluginsDir, 0750)
require.NoError(t, err)
provDashboardsDir := filepath.Join(provDir, "dashboards")
err = os.MkdirAll(provDashboardsDir, 0750)
require.NoError(t, err)
cfg := ini.Empty()
dfltSect := cfg.Section("")
_, err = dfltSect.NewKey("app_mode", "development")
require.NoError(t, err)
pathsSect, err := cfg.NewSection("paths")
require.NoError(t, err)
_, err = pathsSect.NewKey("data", dataDir)
require.NoError(t, err)
_, err = pathsSect.NewKey("logs", logsDir)
require.NoError(t, err)
_, err = pathsSect.NewKey("plugins", pluginsDir)
require.NoError(t, err)
logSect, err := cfg.NewSection("log")
require.NoError(t, err)
_, err = logSect.NewKey("level", "debug")
require.NoError(t, err)
serverSect, err := cfg.NewSection("server")
require.NoError(t, err)
_, err = serverSect.NewKey("port", "0")
require.NoError(t, err)
anonSect, err := cfg.NewSection("auth.anonymous")
require.NoError(t, err)
_, err = anonSect.NewKey("enabled", "true")
require.NoError(t, err)
cfgPath := filepath.Join(cfgDir, "test.ini")
err = cfg.SaveTo(cfgPath)
require.NoError(t, err)
err = fs.CopyFile(filepath.Join(rootDir, "conf", "defaults.ini"), filepath.Join(cfgDir, "defaults.ini"))
require.NoError(t, err)
return tmpDir, cfgPath
}
func startGrafana(t *testing.T, grafDir, cfgPath string, sqlStore *sqlstore.SQLStore) string {
t.Helper()
origSQLStore := registry.GetService(sqlstore.ServiceName)
t.Cleanup(func() {
registry.Register(origSQLStore)
})
registry.Register(&registry.Descriptor{
Name: sqlstore.ServiceName,
Instance: sqlStore,
InitPriority: sqlstore.InitPriority,
})
t.Logf("Registered SQL store %p", sqlStore)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
server, err := server.New(server.Config{
ConfigFile: cfgPath,
HomePath: grafDir,
Listener: listener,
})
require.NoError(t, err)
t.Cleanup(func() {
// Have to reset the route register between tests, since it doesn't get re-created
server.HTTPServer.RouteRegister.Reset()
})
go func() {
// When the server runs, it will also build and initialize the service graph
if err := server.Run(); err != nil {
t.Log("Server exited uncleanly", "error", err)
}
}()
t.Cleanup(func() {
server.Shutdown("")
})
// Wait for Grafana to be ready
addr := listener.Addr().String()
resp, err := http.Get(fmt.Sprintf("http://%s/api/health", addr))
require.NoError(t, err)
require.NotNil(t, resp)
t.Cleanup(func() {
err := resp.Body.Close()
assert.NoError(t, err)
})
require.Equal(t, 200, resp.StatusCode)
t.Logf("Grafana is listening on %s", addr)
return addr
}
func setUpDatabase(t *testing.T, grafDir string) *sqlstore.SQLStore {
t.Helper()
sqlStore := sqlstore.InitTestDB(t, sqlstore.InitTestDBOpt{
EnsureDefaultOrgAndUser: true,
})
// We need the main org, since it's used for anonymous access
org, err := sqlStore.GetOrgByName(sqlstore.MainOrgName)
require.NoError(t, err)
require.NotNil(t, org)
err = sqlStore.WithDbSession(context.Background(), func(sess *sqlstore.DBSession) error {
_, err := sess.Insert(&models.DataSource{
Id: 1,
// This will be the ID of the main org
OrgId: 2,
Name: "Test",
Type: "cloudwatch",
Created: time.Now(),
Updated: time.Now(),
})
return err
})
require.NoError(t, err)
// Make sure changes are synced with other goroutines
err = sqlStore.Sync()
require.NoError(t, err)
return sqlStore
}
......@@ -199,7 +199,7 @@ func (e *cloudWatchExecutor) getCWClient(region string) (cloudwatchiface.CloudWa
if err != nil {
return nil, err
}
return newCWClient(sess), nil
return NewCWClient(sess), nil
}
func (e *cloudWatchExecutor) getCWLogsClient(region string) (cloudwatchlogsiface.CloudWatchLogsAPI, error) {
......@@ -208,7 +208,7 @@ func (e *cloudWatchExecutor) getCWLogsClient(region string) (cloudwatchlogsiface
return nil, err
}
logsClient := newCWLogsClient(sess)
logsClient := NewCWLogsClient(sess)
return logsClient, nil
}
......@@ -452,10 +452,10 @@ func isTerminated(queryStatus string) bool {
return queryStatus == "Complete" || queryStatus == "Cancelled" || queryStatus == "Failed" || queryStatus == "Timeout"
}
// newCWClient is a CloudWatch client factory.
// NewCWClient is a CloudWatch client factory.
//
// Stubbable by tests.
var newCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
var NewCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
client := cloudwatch.New(sess)
client.Handlers.Send.PushFront(func(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
......@@ -464,10 +464,10 @@ var newCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
return client
}
// newCWLogsClient is a CloudWatch logs client factory.
// NewCWLogsClient is a CloudWatch logs client factory.
//
// Stubbable by tests.
var newCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
var NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
client := cloudwatchlogs.New(sess)
client.Handlers.Send.PushFront(func(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
......
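Exporting NewCWClient and NewCWLogsClient as package-level variables is what lets both the unit tests and the new HTTP API tests stub the AWS clients and restore them afterwards. A condensed sketch of that stub-and-restore pattern, written as a hypothetical test helper around the FakeCWClient defined in the test fixtures:

package example

import (
	"testing"

	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"

	"github.com/grafana/grafana/pkg/tsdb/cloudwatch"
)

// stubCWClient is a hypothetical test helper: it swaps the CloudWatch client
// factory for one returning the given fake and restores the original afterwards.
func stubCWClient(t *testing.T, fake cloudwatch.FakeCWClient) {
	t.Helper()

	orig := cloudwatch.NewCWClient
	t.Cleanup(func() {
		cloudwatch.NewCWClient = orig
	})

	cloudwatch.NewCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
		return fake
	}
}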
......@@ -19,19 +19,19 @@ import (
)
func TestQuery_DescribeLogGroups(t *testing.T) {
origNewCWLogsClient := newCWLogsClient
origNewCWLogsClient := NewCWLogsClient
t.Cleanup(func() {
newCWLogsClient = origNewCWLogsClient
NewCWLogsClient = origNewCWLogsClient
})
var cli fakeCWLogsClient
var cli FakeCWLogsClient
newCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
return cli
}
t.Run("Empty log group name prefix", func(t *testing.T) {
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
logGroups: cloudwatchlogs.DescribeLogGroupsOutput{
LogGroups: []*cloudwatchlogs.LogGroup{
{
......@@ -84,7 +84,7 @@ func TestQuery_DescribeLogGroups(t *testing.T) {
})
t.Run("Non-empty log group name prefix", func(t *testing.T) {
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
logGroups: cloudwatchlogs.DescribeLogGroupsOutput{
LogGroups: []*cloudwatchlogs.LogGroup{
{
......@@ -138,18 +138,18 @@ func TestQuery_DescribeLogGroups(t *testing.T) {
}
func TestQuery_GetLogGroupFields(t *testing.T) {
origNewCWLogsClient := newCWLogsClient
origNewCWLogsClient := NewCWLogsClient
t.Cleanup(func() {
newCWLogsClient = origNewCWLogsClient
NewCWLogsClient = origNewCWLogsClient
})
var cli fakeCWLogsClient
var cli FakeCWLogsClient
newCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
return cli
}
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
logGroupFields: cloudwatchlogs.GetLogGroupFieldsOutput{
LogGroupFields: []*cloudwatchlogs.LogGroupField{
{
......@@ -213,19 +213,19 @@ func TestQuery_GetLogGroupFields(t *testing.T) {
}
func TestQuery_StartQuery(t *testing.T) {
origNewCWLogsClient := newCWLogsClient
origNewCWLogsClient := NewCWLogsClient
t.Cleanup(func() {
newCWLogsClient = origNewCWLogsClient
NewCWLogsClient = origNewCWLogsClient
})
var cli fakeCWLogsClient
var cli FakeCWLogsClient
newCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
return cli
}
t.Run("invalid time range", func(t *testing.T) {
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
logGroupFields: cloudwatchlogs.GetLogGroupFieldsOutput{
LogGroupFields: []*cloudwatchlogs.LogGroupField{
{
......@@ -271,7 +271,7 @@ func TestQuery_StartQuery(t *testing.T) {
t.Run("valid time range", func(t *testing.T) {
const refID = "A"
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
logGroupFields: cloudwatchlogs.GetLogGroupFieldsOutput{
LogGroupFields: []*cloudwatchlogs.LogGroupField{
{
......@@ -336,18 +336,18 @@ func TestQuery_StartQuery(t *testing.T) {
}
func TestQuery_StopQuery(t *testing.T) {
origNewCWLogsClient := newCWLogsClient
origNewCWLogsClient := NewCWLogsClient
t.Cleanup(func() {
newCWLogsClient = origNewCWLogsClient
NewCWLogsClient = origNewCWLogsClient
})
var cli fakeCWLogsClient
var cli FakeCWLogsClient
newCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
return cli
}
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
logGroupFields: cloudwatchlogs.GetLogGroupFieldsOutput{
LogGroupFields: []*cloudwatchlogs.LogGroupField{
{
......@@ -405,19 +405,19 @@ func TestQuery_StopQuery(t *testing.T) {
}
func TestQuery_GetQueryResults(t *testing.T) {
origNewCWLogsClient := newCWLogsClient
origNewCWLogsClient := NewCWLogsClient
t.Cleanup(func() {
newCWLogsClient = origNewCWLogsClient
NewCWLogsClient = origNewCWLogsClient
})
var cli fakeCWLogsClient
var cli FakeCWLogsClient
newCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
NewCWLogsClient = func(sess *session.Session) cloudwatchlogsiface.CloudWatchLogsAPI {
return cli
}
const refID = "A"
cli = fakeCWLogsClient{
cli = FakeCWLogsClient{
queryResults: cloudwatchlogs.GetQueryResultsOutput{
Results: [][]*cloudwatchlogs.ResultField{
{
......
......@@ -20,20 +20,20 @@ import (
)
func TestQuery_Metrics(t *testing.T) {
origNewCWClient := newCWClient
origNewCWClient := NewCWClient
t.Cleanup(func() {
newCWClient = origNewCWClient
NewCWClient = origNewCWClient
})
var client fakeCWClient
var client FakeCWClient
newCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
NewCWClient = func(sess *session.Session) cloudwatchiface.CloudWatchAPI {
return client
}
t.Run("Custom metrics", func(t *testing.T) {
client = fakeCWClient{
metrics: []*cloudwatch.Metric{
client = FakeCWClient{
Metrics: []*cloudwatch.Metric{
{
MetricName: aws.String("Test_MetricName"),
Dimensions: []*cloudwatch.Dimension{
......@@ -89,8 +89,8 @@ func TestQuery_Metrics(t *testing.T) {
})
t.Run("Dimension keys for custom metrics", func(t *testing.T) {
client = fakeCWClient{
metrics: []*cloudwatch.Metric{
client = FakeCWClient{
Metrics: []*cloudwatch.Metric{
{
MetricName: aws.String("Test_MetricName"),
Dimensions: []*cloudwatch.Dimension{
......
......@@ -43,46 +43,46 @@ func fakeDataSource(cfgs ...fakeDataSourceCfg) *models.DataSource {
}
}
type fakeCWLogsClient struct {
type FakeCWLogsClient struct {
cloudwatchlogsiface.CloudWatchLogsAPI
logGroups cloudwatchlogs.DescribeLogGroupsOutput
logGroupFields cloudwatchlogs.GetLogGroupFieldsOutput
queryResults cloudwatchlogs.GetQueryResultsOutput
}
func (m fakeCWLogsClient) GetQueryResultsWithContext(ctx context.Context, input *cloudwatchlogs.GetQueryResultsInput, option ...request.Option) (*cloudwatchlogs.GetQueryResultsOutput, error) {
func (m FakeCWLogsClient) GetQueryResultsWithContext(ctx context.Context, input *cloudwatchlogs.GetQueryResultsInput, option ...request.Option) (*cloudwatchlogs.GetQueryResultsOutput, error) {
return &m.queryResults, nil
}
func (m fakeCWLogsClient) StartQueryWithContext(ctx context.Context, input *cloudwatchlogs.StartQueryInput, option ...request.Option) (*cloudwatchlogs.StartQueryOutput, error) {
func (m FakeCWLogsClient) StartQueryWithContext(ctx context.Context, input *cloudwatchlogs.StartQueryInput, option ...request.Option) (*cloudwatchlogs.StartQueryOutput, error) {
return &cloudwatchlogs.StartQueryOutput{
QueryId: aws.String("abcd-efgh-ijkl-mnop"),
}, nil
}
func (m fakeCWLogsClient) StopQueryWithContext(ctx context.Context, input *cloudwatchlogs.StopQueryInput, option ...request.Option) (*cloudwatchlogs.StopQueryOutput, error) {
func (m FakeCWLogsClient) StopQueryWithContext(ctx context.Context, input *cloudwatchlogs.StopQueryInput, option ...request.Option) (*cloudwatchlogs.StopQueryOutput, error) {
return &cloudwatchlogs.StopQueryOutput{
Success: aws.Bool(true),
}, nil
}
func (m fakeCWLogsClient) DescribeLogGroupsWithContext(ctx context.Context, input *cloudwatchlogs.DescribeLogGroupsInput, option ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
func (m FakeCWLogsClient) DescribeLogGroupsWithContext(ctx context.Context, input *cloudwatchlogs.DescribeLogGroupsInput, option ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
return &m.logGroups, nil
}
func (m fakeCWLogsClient) GetLogGroupFieldsWithContext(ctx context.Context, input *cloudwatchlogs.GetLogGroupFieldsInput, option ...request.Option) (*cloudwatchlogs.GetLogGroupFieldsOutput, error) {
func (m FakeCWLogsClient) GetLogGroupFieldsWithContext(ctx context.Context, input *cloudwatchlogs.GetLogGroupFieldsInput, option ...request.Option) (*cloudwatchlogs.GetLogGroupFieldsOutput, error) {
return &m.logGroupFields, nil
}
type fakeCWClient struct {
type FakeCWClient struct {
cloudwatchiface.CloudWatchAPI
metrics []*cloudwatch.Metric
Metrics []*cloudwatch.Metric
}
func (c fakeCWClient) ListMetricsPages(input *cloudwatch.ListMetricsInput, fn func(*cloudwatch.ListMetricsOutput, bool) bool) error {
func (c FakeCWClient) ListMetricsPages(input *cloudwatch.ListMetricsInput, fn func(*cloudwatch.ListMetricsOutput, bool) bool) error {
fn(&cloudwatch.ListMetricsOutput{
Metrics: c.metrics,
Metrics: c.Metrics,
}, true)
return nil
}
......
package tsdb
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana/pkg/components/null"
......@@ -42,6 +44,121 @@ type QueryResult struct {
Dataframes DataFrames `json:"dataframes"`
}
// UnmarshalJSON deserializes a QueryResult from JSON.
//
// Deserialization support is required by tests.
func (r *QueryResult) UnmarshalJSON(b []byte) error {
m := map[string]interface{}{}
// TODO: Use JSON decoder
if err := json.Unmarshal(b, &m); err != nil {
return err
}
refID, ok := m["refId"].(string)
if !ok {
return fmt.Errorf("can't decode field refId - not a string")
}
var meta *simplejson.Json
if m["meta"] != nil {
mm, ok := m["meta"].(map[string]interface{})
if !ok {
return fmt.Errorf("can't decode field meta - not a JSON object")
}
meta = simplejson.NewFromAny(mm)
}
var series TimeSeriesSlice
/* TODO
if m["series"] != nil {
}
*/
var tables []*Table
if m["tables"] != nil {
ts, ok := m["tables"].([]interface{})
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
for _, ti := range ts {
tm, ok := ti.(map[string]interface{})
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
var columns []TableColumn
cs, ok := tm["columns"].([]interface{})
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
for _, ci := range cs {
cm, ok := ci.(map[string]interface{})
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
val, ok := cm["text"].(string)
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
columns = append(columns, TableColumn{Text: val})
}
rs, ok := tm["rows"].([]interface{})
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
var rows []RowValues
for _, ri := range rs {
vals, ok := ri.([]interface{})
if !ok {
return fmt.Errorf("can't decode field tables - not an array of Tables")
}
rows = append(rows, vals)
}
tables = append(tables, &Table{
Columns: columns,
Rows: rows,
})
}
}
var dfs *dataFrames
if m["dataframes"] != nil {
raw, ok := m["dataframes"].([]interface{})
if !ok {
return fmt.Errorf("can't decode field dataframes - not an array of byte arrays")
}
var encoded [][]byte
for _, ra := range raw {
encS, ok := ra.(string)
if !ok {
return fmt.Errorf("can't decode field dataframes - not an array of byte arrays")
}
enc, err := base64.StdEncoding.DecodeString(encS)
if err != nil {
return fmt.Errorf("can't decode field dataframes - not an array of arrow frames")
}
encoded = append(encoded, enc)
}
decoded, err := data.UnmarshalArrowFrames(encoded)
if err != nil {
return err
}
dfs = &dataFrames{
decoded: decoded,
encoded: encoded,
}
}
r.RefId = refID
r.Meta = meta
r.Series = series
r.Tables = tables
if dfs != nil {
r.Dataframes = dfs
}
return nil
}
type TimeSeries struct {
Name string `json:"name"`
Points TimeSeriesPoints `json:"points"`
......
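The custom UnmarshalJSON exists so that tests can decode an /api/ds/query response body back into a tsdb.Response and compare it with an expected value, as makeCWRequest does above. A small illustrative sketch, assuming the Results field carries the usual "results" JSON tag:

package example

import (
	"encoding/json"
	"fmt"

	"github.com/grafana/grafana/pkg/tsdb"
)

// decodeExample is illustrative only: it decodes a minimal /api/ds/query-style
// payload into a tsdb.Response, exercising the custom QueryResult.UnmarshalJSON.
// It assumes Response.Results carries the "results" JSON tag.
func decodeExample() error {
	raw := []byte(`{"results":{"A":{"refId":"A"}}}`)

	var resp tsdb.Response
	if err := json.Unmarshal(raw, &resp); err != nil {
		return err
	}

	fmt.Println(resp.Results["A"].RefId) // Prints: A
	return nil
}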