Commit ed6cca61 by Marcus Efraimsson Committed by GitHub

Merge pull request #15051 from ellisvlad/13711_parse_database_config_ipv6_host

Parse database host correctly when using IPv6
parents 577f35f6 9692955d
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore/migrator" "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/services/sqlstore/sqlutil" "github.com/grafana/grafana/pkg/services/sqlstore/sqlutil"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/go-xorm/xorm" "github.com/go-xorm/xorm"
...@@ -222,13 +223,9 @@ func (ss *SqlStore) buildConnectionString() (string, error) { ...@@ -222,13 +223,9 @@ func (ss *SqlStore) buildConnectionString() (string, error) {
cnnstr += "&tls=custom" cnnstr += "&tls=custom"
} }
case migrator.POSTGRES: case migrator.POSTGRES:
var host, port = "127.0.0.1", "5432" host, port, err := util.SplitIpPort(ss.dbCfg.Host, "5432")
fields := strings.Split(ss.dbCfg.Host, ":") if err != nil {
if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 { return "", err
host = fields[0]
}
if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
port = fields[1]
} }
if ss.dbCfg.Pwd == "" { if ss.dbCfg.Pwd == "" {
ss.dbCfg.Pwd = "''" ss.dbCfg.Pwd = "''"
......
package sqlstore
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
"github.com/grafana/grafana/pkg/setting"
)
type sqlStoreTest struct {
name string
dbType string
dbHost string
connStrValues []string
}
var sqlStoreTestCases = []sqlStoreTest{
{
name: "MySQL IPv4",
dbType: "mysql",
dbHost: "1.2.3.4:5678",
connStrValues: []string{"tcp(1.2.3.4:5678)"},
},
{
name: "Postgres IPv4",
dbType: "postgres",
dbHost: "1.2.3.4:5678",
connStrValues: []string{"host=1.2.3.4", "port=5678"},
},
{
name: "Postgres IPv4 (Default Port)",
dbType: "postgres",
dbHost: "1.2.3.4",
connStrValues: []string{"host=1.2.3.4", "port=5432"},
},
{
name: "MySQL IPv4 (Default Port)",
dbType: "mysql",
dbHost: "1.2.3.4",
connStrValues: []string{"tcp(1.2.3.4)"},
},
{
name: "MySQL IPv6",
dbType: "mysql",
dbHost: "[fe80::24e8:31b2:91df:b177]:1234",
connStrValues: []string{"tcp([fe80::24e8:31b2:91df:b177]:1234)"},
},
{
name: "Postgres IPv6",
dbType: "postgres",
dbHost: "[fe80::24e8:31b2:91df:b177]:1234",
connStrValues: []string{"host=fe80::24e8:31b2:91df:b177", "port=1234"},
},
{
name: "MySQL IPv6 (Default Port)",
dbType: "mysql",
dbHost: "::1",
connStrValues: []string{"tcp(::1)"},
},
{
name: "Postgres IPv6 (Default Port)",
dbType: "postgres",
dbHost: "::1",
connStrValues: []string{"host=::1", "port=5432"},
},
}
func TestSqlConnectionString(t *testing.T) {
Convey("Testing SQL Connection Strings", t, func() {
t.Helper()
for _, testCase := range sqlStoreTestCases {
Convey(testCase.name, func() {
sqlstore := &SqlStore{}
sqlstore.Cfg = makeSqlStoreTestConfig(testCase.dbType, testCase.dbHost)
sqlstore.readConfig()
connStr, err := sqlstore.buildConnectionString()
So(err, ShouldBeNil)
for _, connSubStr := range testCase.connStrValues {
So(connStr, ShouldContainSubstring, connSubStr)
}
})
}
})
}
func makeSqlStoreTestConfig(dbType string, host string) *setting.Cfg {
cfg := setting.NewCfg()
sec, _ := cfg.Raw.NewSection("database")
sec.NewKey("type", dbType)
sec.NewKey("host", host)
sec.NewKey("user", "user")
sec.NewKey("name", "test_db")
sec.NewKey("password", "pass")
return cfg
}
...@@ -4,13 +4,13 @@ import ( ...@@ -4,13 +4,13 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"strconv" "strconv"
"strings"
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/go-xorm/core" "github.com/go-xorm/core"
"github.com/grafana/grafana/pkg/log" "github.com/grafana/grafana/pkg/log"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/tsdb" "github.com/grafana/grafana/pkg/tsdb"
"github.com/grafana/grafana/pkg/util"
) )
func init() { func init() {
...@@ -20,7 +20,10 @@ func init() { ...@@ -20,7 +20,10 @@ func init() {
func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) { func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoint, error) {
logger := log.New("tsdb.mssql") logger := log.New("tsdb.mssql")
cnnstr := generateConnectionString(datasource) cnnstr, err := generateConnectionString(datasource)
if err != nil {
return nil, err
}
logger.Debug("getEngine", "connection", cnnstr) logger.Debug("getEngine", "connection", cnnstr)
config := tsdb.SqlQueryEndpointConfiguration{ config := tsdb.SqlQueryEndpointConfiguration{
...@@ -37,7 +40,7 @@ func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoin ...@@ -37,7 +40,7 @@ func newMssqlQueryEndpoint(datasource *models.DataSource) (tsdb.TsdbQueryEndpoin
return tsdb.NewSqlQueryEndpoint(&config, &rowTransformer, newMssqlMacroEngine(), logger) return tsdb.NewSqlQueryEndpoint(&config, &rowTransformer, newMssqlMacroEngine(), logger)
} }
func generateConnectionString(datasource *models.DataSource) string { func generateConnectionString(datasource *models.DataSource) (string, error) {
password := "" password := ""
for key, value := range datasource.SecureJsonData.Decrypt() { for key, value := range datasource.SecureJsonData.Decrypt() {
if key == "password" { if key == "password" {
...@@ -46,12 +49,11 @@ func generateConnectionString(datasource *models.DataSource) string { ...@@ -46,12 +49,11 @@ func generateConnectionString(datasource *models.DataSource) string {
} }
} }
hostParts := strings.Split(datasource.Url, ":") server, port, err := util.SplitIpPort(datasource.Url, "1433")
if len(hostParts) < 2 { if err != nil {
hostParts = append(hostParts, "1433") return "", err
} }
server, port := hostParts[0], hostParts[1]
encrypt := datasource.JsonData.Get("encrypt").MustString("false") encrypt := datasource.JsonData.Get("encrypt").MustString("false")
connStr := fmt.Sprintf("server=%s;port=%s;database=%s;user id=%s;password=%s;", connStr := fmt.Sprintf("server=%s;port=%s;database=%s;user id=%s;password=%s;",
server, server,
...@@ -63,7 +65,7 @@ func generateConnectionString(datasource *models.DataSource) string { ...@@ -63,7 +65,7 @@ func generateConnectionString(datasource *models.DataSource) string {
if encrypt != "false" { if encrypt != "false" {
connStr += fmt.Sprintf("encrypt=%s;", encrypt) connStr += fmt.Sprintf("encrypt=%s;", encrypt)
} }
return connStr return connStr, nil
} }
type mssqlRowTransformer struct { type mssqlRowTransformer struct {
......
package util
import (
"net"
)
func SplitIpPort(ipStr string, portDefault string) (ip string, port string, err error) {
ipAddr := net.ParseIP(ipStr)
if ipAddr == nil {
// Port was included
ip, port, err = net.SplitHostPort(ipStr)
if err != nil {
return "", "", err
}
} else {
// No port was included
ip = ipAddr.String()
port = portDefault
}
return ip, port, nil
}
package util
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestSplitIpPort(t *testing.T) {
Convey("When parsing an IPv4 without explicit port", t, func() {
ip, port, err := SplitIpPort("1.2.3.4", "5678")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "1.2.3.4")
So(port, ShouldEqual, "5678")
})
Convey("When parsing an IPv6 without explicit port", t, func() {
ip, port, err := SplitIpPort("::1", "5678")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "::1")
So(port, ShouldEqual, "5678")
})
Convey("When parsing an IPv4 with explicit port", t, func() {
ip, port, err := SplitIpPort("1.2.3.4:56", "78")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "1.2.3.4")
So(port, ShouldEqual, "56")
})
Convey("When parsing an IPv6 with explicit port", t, func() {
ip, port, err := SplitIpPort("[::1]:56", "78")
So(err, ShouldEqual, nil)
So(ip, ShouldEqual, "::1")
So(port, ShouldEqual, "56")
})
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment