Commit 8cd93f0b by Weeco Committed by Marcus Efraimsson

Datasource: Add custom headers on alerting queries (#19508)

* Add custom headers on alerting queries

Reference issue #15381

Signed-off-by: Martin Schneppenheim <martin.schneppenheim@rewe-digital.com>

* Fix datasource transport tests

* Migrate decrypting header test to models pkg

* Check correct header

* Add HTTP transport test

Fixes #15381
parent a427ff7f
...@@ -121,28 +121,6 @@ func (proxy *DataSourceProxy) addTraceFromHeaderValue(span opentracing.Span, hea ...@@ -121,28 +121,6 @@ func (proxy *DataSourceProxy) addTraceFromHeaderValue(span opentracing.Span, hea
} }
} }
func (proxy *DataSourceProxy) useCustomHeaders(req *http.Request) {
decryptSdj := proxy.ds.SecureJsonData.Decrypt()
index := 1
for {
headerNameSuffix := fmt.Sprintf("httpHeaderName%d", index)
headerValueSuffix := fmt.Sprintf("httpHeaderValue%d", index)
if key := proxy.ds.JsonData.Get(headerNameSuffix).MustString(); key != "" {
if val, ok := decryptSdj[headerValueSuffix]; ok {
// remove if exists
if req.Header.Get(key) != "" {
req.Header.Del(key)
}
req.Header.Add(key, val)
logger.Debug("Using custom header ", "CustomHeaders", key)
}
} else {
break
}
index += 1
}
}
func (proxy *DataSourceProxy) getDirector() func(req *http.Request) { func (proxy *DataSourceProxy) getDirector() func(req *http.Request) {
return func(req *http.Request) { return func(req *http.Request) {
req.URL.Scheme = proxy.targetUrl.Scheme req.URL.Scheme = proxy.targetUrl.Scheme
...@@ -171,11 +149,6 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) { ...@@ -171,11 +149,6 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) {
req.Header.Add("Authorization", util.GetBasicAuthHeader(proxy.ds.BasicAuthUser, proxy.ds.DecryptedBasicAuthPassword())) req.Header.Add("Authorization", util.GetBasicAuthHeader(proxy.ds.BasicAuthUser, proxy.ds.DecryptedBasicAuthPassword()))
} }
// Lookup and use custom headers
if proxy.ds.SecureJsonData != nil {
proxy.useCustomHeaders(req)
}
dsAuth := req.Header.Get("X-DS-Authorization") dsAuth := req.Header.Get("X-DS-Authorization")
if len(dsAuth) > 0 { if len(dsAuth) > 0 {
req.Header.Del("X-DS-Authorization") req.Header.Del("X-DS-Authorization")
......
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins"
...@@ -331,37 +330,6 @@ func TestDSRouteRule(t *testing.T) { ...@@ -331,37 +330,6 @@ func TestDSRouteRule(t *testing.T) {
}) })
}) })
Convey("When proxying a data source with custom headers specified", func() {
plugin := &plugins.DataSourcePlugin{}
encryptedData, err := util.Encrypt([]byte(`Bearer xf5yhfkpsnmgo`), setting.SecretKey)
ds := &m.DataSource{
Type: m.DS_PROMETHEUS,
Url: "http://prometheus:9090",
JsonData: simplejson.NewFromAny(map[string]interface{}{
"httpHeaderName1": "Authorization",
}),
SecureJsonData: map[string][]byte{
"httpHeaderValue1": encryptedData,
},
}
ctx := &m.ReqContext{}
proxy := NewDataSourceProxy(ds, plugin, ctx, "", &setting.Cfg{})
requestURL, _ := url.Parse("http://grafana.com/sub")
req := http.Request{URL: requestURL, Header: make(http.Header)}
proxy.getDirector()(&req)
if err != nil {
log.Fatal(4, err.Error())
}
Convey("Match header value after decryption", func() {
So(req.Header.Get("Authorization"), ShouldEqual, "Bearer xf5yhfkpsnmgo")
})
})
Convey("When proxying a custom datasource", func() { Convey("When proxying a custom datasource", func() {
plugin := &plugins.DataSourcePlugin{} plugin := &plugins.DataSourcePlugin{}
ds := &m.DataSource{ ds := &m.DataSource{
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"sync" "sync"
...@@ -17,10 +18,25 @@ type proxyTransportCache struct { ...@@ -17,10 +18,25 @@ type proxyTransportCache struct {
sync.Mutex sync.Mutex
} }
// dataSourceTransport implements http.RoundTripper (https://golang.org/pkg/net/http/#RoundTripper)
type dataSourceTransport struct {
headers map[string]string
transport *http.Transport
}
// RoundTrip executes a single HTTP transaction, returning a Response for the provided Request.
func (d *dataSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, value := range d.headers {
req.Header.Set(key, value)
}
return d.transport.RoundTrip(req)
}
type cachedTransport struct { type cachedTransport struct {
updated time.Time updated time.Time
*http.Transport *dataSourceTransport
} }
var ptc = proxyTransportCache{ var ptc = proxyTransportCache{
...@@ -40,12 +56,12 @@ func (ds *DataSource) GetHttpClient() (*http.Client, error) { ...@@ -40,12 +56,12 @@ func (ds *DataSource) GetHttpClient() (*http.Client, error) {
}, nil }, nil
} }
func (ds *DataSource) GetHttpTransport() (*http.Transport, error) { func (ds *DataSource) GetHttpTransport() (*dataSourceTransport, error) {
ptc.Lock() ptc.Lock()
defer ptc.Unlock() defer ptc.Unlock()
if t, present := ptc.cache[ds.Id]; present && ds.Updated.Equal(t.updated) { if t, present := ptc.cache[ds.Id]; present && ds.Updated.Equal(t.updated) {
return t.Transport, nil return t.dataSourceTransport, nil
} }
tlsConfig, err := ds.GetTLSConfig() tlsConfig, err := ds.GetTLSConfig()
...@@ -55,6 +71,8 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) { ...@@ -55,6 +71,8 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) {
tlsConfig.Renegotiation = tls.RenegotiateFreelyAsClient tlsConfig.Renegotiation = tls.RenegotiateFreelyAsClient
// Create transport which adds all
customHeaders := ds.getCustomHeaders()
transport := &http.Transport{ transport := &http.Transport{
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
...@@ -68,12 +86,17 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) { ...@@ -68,12 +86,17 @@ func (ds *DataSource) GetHttpTransport() (*http.Transport, error) {
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
} }
dsTransport := &dataSourceTransport{
headers: customHeaders,
transport: transport,
}
ptc.cache[ds.Id] = cachedTransport{ ptc.cache[ds.Id] = cachedTransport{
Transport: transport, dataSourceTransport: dsTransport,
updated: ds.Updated, updated: ds.Updated,
} }
return transport, nil return dsTransport, nil
} }
func (ds *DataSource) GetTLSConfig() (*tls.Config, error) { func (ds *DataSource) GetTLSConfig() (*tls.Config, error) {
...@@ -110,3 +133,32 @@ func (ds *DataSource) GetTLSConfig() (*tls.Config, error) { ...@@ -110,3 +133,32 @@ func (ds *DataSource) GetTLSConfig() (*tls.Config, error) {
return tlsConfig, nil return tlsConfig, nil
} }
// getCustomHeaders returns a map with all the to be set headers
// The map key represents the HeaderName and the value represents this header's value
func (ds *DataSource) getCustomHeaders() map[string]string {
headers := make(map[string]string)
if ds.JsonData == nil {
return headers
}
decrypted := ds.SecureJsonData.Decrypt()
index := 1
for {
headerNameSuffix := fmt.Sprintf("httpHeaderName%d", index)
headerValueSuffix := fmt.Sprintf("httpHeaderValue%d", index)
key := ds.JsonData.Get(headerNameSuffix).MustString()
if key == "" {
// No (more) header values are available
break
}
if val, ok := decrypted[headerValueSuffix]; ok {
headers[key] = val
}
index++
}
return headers
}
package models package models
import ( import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing" "testing"
"time" "time"
...@@ -31,13 +35,13 @@ func TestDataSourceCache(t *testing.T) { ...@@ -31,13 +35,13 @@ func TestDataSourceCache(t *testing.T) {
So(t2, ShouldEqual, t1) So(t2, ShouldEqual, t1)
}) })
Convey("Should verify TLS by default", func() { Convey("Should verify TLS by default", func() {
So(t1.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false) So(t1.transport.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false)
}) })
Convey("Should have no TLS client certificate configured", func() { Convey("Should have no TLS client certificate configured", func() {
So(len(t1.TLSClientConfig.Certificates), ShouldEqual, 0) So(len(t1.transport.TLSClientConfig.Certificates), ShouldEqual, 0)
}) })
Convey("Should have no user-supplied TLS CA onfigured", func() { Convey("Should have no user-supplied TLS CA onfigured", func() {
So(t1.TLSClientConfig.RootCAs, ShouldBeNil) So(t1.transport.TLSClientConfig.RootCAs, ShouldBeNil)
}) })
}) })
...@@ -62,13 +66,13 @@ func TestDataSourceCache(t *testing.T) { ...@@ -62,13 +66,13 @@ func TestDataSourceCache(t *testing.T) {
So(err, ShouldBeNil) So(err, ShouldBeNil)
Convey("Should verify TLS by default", func() { Convey("Should verify TLS by default", func() {
So(t1.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false) So(t1.transport.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false)
}) })
Convey("Should have no TLS client certificate configured", func() { Convey("Should have no TLS client certificate configured", func() {
So(len(t1.TLSClientConfig.Certificates), ShouldEqual, 0) So(len(t1.transport.TLSClientConfig.Certificates), ShouldEqual, 0)
}) })
Convey("Should have no user-supplied TLS CA configured", func() { Convey("Should have no user-supplied TLS CA configured", func() {
So(t1.TLSClientConfig.RootCAs, ShouldBeNil) So(t1.transport.TLSClientConfig.RootCAs, ShouldBeNil)
}) })
ds.JsonData = nil ds.JsonData = nil
...@@ -79,7 +83,7 @@ func TestDataSourceCache(t *testing.T) { ...@@ -79,7 +83,7 @@ func TestDataSourceCache(t *testing.T) {
So(err, ShouldBeNil) So(err, ShouldBeNil)
Convey("Should have no user-supplied TLS CA configured after the update", func() { Convey("Should have no user-supplied TLS CA configured after the update", func() {
So(t2.TLSClientConfig.RootCAs, ShouldBeNil) So(t2.transport.TLSClientConfig.RootCAs, ShouldBeNil)
}) })
}) })
...@@ -110,10 +114,10 @@ func TestDataSourceCache(t *testing.T) { ...@@ -110,10 +114,10 @@ func TestDataSourceCache(t *testing.T) {
So(err, ShouldBeNil) So(err, ShouldBeNil)
Convey("Should verify TLS by default", func() { Convey("Should verify TLS by default", func() {
So(tr.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false) So(tr.transport.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false)
}) })
Convey("Should have a TLS client certificate configured", func() { Convey("Should have a TLS client certificate configured", func() {
So(len(tr.TLSClientConfig.Certificates), ShouldEqual, 1) So(len(tr.transport.TLSClientConfig.Certificates), ShouldEqual, 1)
}) })
}) })
...@@ -139,10 +143,10 @@ func TestDataSourceCache(t *testing.T) { ...@@ -139,10 +143,10 @@ func TestDataSourceCache(t *testing.T) {
So(err, ShouldBeNil) So(err, ShouldBeNil)
Convey("Should verify TLS by default", func() { Convey("Should verify TLS by default", func() {
So(tr.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false) So(tr.transport.TLSClientConfig.InsecureSkipVerify, ShouldEqual, false)
}) })
Convey("Should have a TLS CA configured", func() { Convey("Should have a TLS CA configured", func() {
So(len(tr.TLSClientConfig.RootCAs.Subjects()), ShouldEqual, 1) So(len(tr.transport.TLSClientConfig.RootCAs.Subjects()), ShouldEqual, 1)
}) })
}) })
...@@ -163,7 +167,67 @@ func TestDataSourceCache(t *testing.T) { ...@@ -163,7 +167,67 @@ func TestDataSourceCache(t *testing.T) {
So(err, ShouldBeNil) So(err, ShouldBeNil)
Convey("Should skip TLS verification", func() { Convey("Should skip TLS verification", func() {
So(tr.TLSClientConfig.InsecureSkipVerify, ShouldEqual, true) So(tr.transport.TLSClientConfig.InsecureSkipVerify, ShouldEqual, true)
})
})
Convey("When caching a datasource proxy with custom headers specified", t, func() {
clearCache()
json := simplejson.NewFromAny(map[string]interface{}{
"httpHeaderName1": "Authorization",
})
encryptedData, err := util.Encrypt([]byte(`Bearer xf5yhfkpsnmgo`), setting.SecretKey)
if err != nil {
log.Fatal(err.Error())
}
ds := DataSource{
Id: 1,
Url: "http://k8s:8001",
Type: "Kubernetes",
JsonData: json,
SecureJsonData: map[string][]byte{"httpHeaderValue1": encryptedData},
}
Convey("Should match header value after decryption", func() {
headers := ds.getCustomHeaders()
So(headers["Authorization"], ShouldEqual, "Bearer xf5yhfkpsnmgo")
})
Convey("Should add header fields in HTTP Transport", func() {
// 1. Start HTTP test server which checks the request headers
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Can't use So() here, see: https://github.com/smartystreets/goconvey/issues/561
if r.Header.Get("Authorization") == "Bearer xf5yhfkpsnmgo" {
w.WriteHeader(200)
w.Write([]byte("Ok"))
return
}
w.WriteHeader(403)
w.Write([]byte("Invalid bearer token provided"))
}))
defer backend.Close()
// 2. Get HTTP transport from datasoruce which uses the test server as backend
ds.Url = backend.URL
transport, err := ds.GetHttpTransport()
if err != nil {
log.Fatal(err.Error())
}
// 3. Send test request which should have the Authorization header set
req := httptest.NewRequest("GET", backend.URL+"/test-headers", nil)
res, err := transport.RoundTrip(req)
if err != nil {
log.Fatal(err.Error())
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
log.Fatal(err.Error())
}
bodyStr := string(body)
So(bodyStr, ShouldEqual, "Ok")
}) })
}) })
} }
......
...@@ -21,11 +21,11 @@ import ( ...@@ -21,11 +21,11 @@ import (
) )
type PrometheusExecutor struct { type PrometheusExecutor struct {
Transport *http.Transport Transport http.RoundTripper
} }
type basicAuthTransport struct { type basicAuthTransport struct {
*http.Transport Transport http.RoundTripper
username string username string
password string password string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment