Commit 8d5b0084 by Arve Knudsen Committed by GitHub

Middleware: Simplifications (#29491)

* Middleware: Simplify

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>

* middleware: Rename auth_proxy directory to authproxy

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>
parent 1ac02390
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
) )
// NewFakeStore creates store for testing // NewFakeStore creates store for testing
...@@ -26,9 +27,7 @@ func NewFakeStore(t *testing.T) *RemoteCache { ...@@ -26,9 +27,7 @@ func NewFakeStore(t *testing.T) *RemoteCache {
} }
err := dc.Init() err := dc.Init()
if err != nil { require.NoError(t, err, "Failed to init remote cache for test")
t.Fatalf("failed to init remote cache for test. error: %v", err)
}
return dc return dc
} }
...@@ -57,7 +57,7 @@ func notAuthorized(c *models.ReqContext) { ...@@ -57,7 +57,7 @@ func notAuthorized(c *models.ReqContext) {
// remove any forceLogin=true params // remove any forceLogin=true params
redirectTo = removeForceLoginParams(redirectTo) redirectTo = removeForceLoginParams(redirectTo)
WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, newCookieOptions) WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil)
c.Redirect(setting.AppSubUrl + "/login") c.Redirect(setting.AppSubUrl + "/login")
} }
......
...@@ -5,7 +5,7 @@ import ( ...@@ -5,7 +5,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" "github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
...@@ -45,7 +45,7 @@ func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqCon ...@@ -45,7 +45,7 @@ func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqCon
} }
// Check if allowed to continue with this IP // Check if allowed to continue with this IP
if result, err := auth.IsAllowedIP(); !result { if err := auth.IsAllowedIP(); err != nil {
logger.Error( logger.Error(
"Failed to check whitelisted IP addresses", "Failed to check whitelisted IP addresses",
"message", err.Error(), "message", err.Error(),
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" "github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -29,7 +29,7 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { ...@@ -29,7 +29,7 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
cmd.Result = &models.User{Id: userID} cmd.Result = &models.User{Id: userID}
return nil return nil
} }
getSignedUserHandler := func(cmd *models.GetSignedInUserQuery) error { getUserHandler := func(cmd *models.GetSignedInUserQuery) error {
// Simulate that the cached user ID is stale // Simulate that the cached user ID is stale
if cmd.UserId != userID { if cmd.UserId != userID {
return models.ErrUserNotFound return models.ErrUserNotFound
...@@ -46,7 +46,7 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { ...@@ -46,7 +46,7 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
origEnabled := setting.AuthProxyEnabled origEnabled := setting.AuthProxyEnabled
origHeaderProperty := setting.AuthProxyHeaderProperty origHeaderProperty := setting.AuthProxyHeaderProperty
bus.AddHandler("", upsertHandler) bus.AddHandler("", upsertHandler)
bus.AddHandler("", getSignedUserHandler) bus.AddHandler("", getUserHandler)
t.Cleanup(func() { t.Cleanup(func() {
setting.AuthProxyHeaderName = origHeaderName setting.AuthProxyHeaderName = origHeaderName
setting.AuthProxyEnabled = origEnabled setting.AuthProxyEnabled = origEnabled
......
...@@ -114,11 +114,11 @@ func (auth *AuthProxy) HasHeader() bool { ...@@ -114,11 +114,11 @@ func (auth *AuthProxy) HasHeader() bool {
} }
// IsAllowedIP compares presented IP with the whitelist one // IsAllowedIP compares presented IP with the whitelist one
func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { func (auth *AuthProxy) IsAllowedIP() *Error {
ip := auth.ctx.Req.RemoteAddr ip := auth.ctx.Req.RemoteAddr
if len(strings.TrimSpace(auth.whitelistIP)) == 0 { if len(strings.TrimSpace(auth.whitelistIP)) == 0 {
return true, nil return nil
} }
proxies := strings.Split(auth.whitelistIP, ",") proxies := strings.Split(auth.whitelistIP, ",")
...@@ -126,7 +126,7 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { ...@@ -126,7 +126,7 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) {
for _, proxy := range proxies { for _, proxy := range proxies {
result, err := coerceProxyAddress(proxy) result, err := coerceProxyAddress(proxy)
if err != nil { if err != nil {
return false, newError("Could not get the network", err) return newError("could not get the network", err)
} }
proxyObjs = append(proxyObjs, result) proxyObjs = append(proxyObjs, result)
...@@ -134,13 +134,13 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { ...@@ -134,13 +134,13 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) {
sourceIP, _, err := net.SplitHostPort(ip) sourceIP, _, err := net.SplitHostPort(ip)
if err != nil { if err != nil {
return false, newError("could not parse address", err) return newError("could not parse address", err)
} }
sourceObj := net.ParseIP(sourceIP) sourceObj := net.ParseIP(sourceIP)
for _, proxyObj := range proxyObjs { for _, proxyObj := range proxyObjs {
if proxyObj.Contains(sourceObj) { if proxyObj.Contains(sourceObj) {
return true, nil return nil
} }
} }
...@@ -148,7 +148,7 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) { ...@@ -148,7 +148,7 @@ func (auth *AuthProxy) IsAllowedIP() (bool, *Error) {
"request for user (%s) from %s is not from the authentication proxy", auth.header, "request for user (%s) from %s is not from the authentication proxy", auth.header,
sourceIP, sourceIP,
) )
return false, newError("Proxy authentication required", err) return newError("proxy authentication required", err)
} }
func HashCacheKey(key string) string { func HashCacheKey(key string) string {
...@@ -232,7 +232,7 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { ...@@ -232,7 +232,7 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
func (auth *AuthProxy) LoginViaLDAP() (int64, *Error) { func (auth *AuthProxy) LoginViaLDAP() (int64, *Error) {
config, err := getLDAPConfig() config, err := getLDAPConfig()
if err != nil { if err != nil {
return 0, newError("Failed to get LDAP config", nil) return 0, newError("failed to get LDAP config", nil)
} }
extUser, _, err := newLDAP(config.Servers).User(auth.header) extUser, _, err := newLDAP(config.Servers).User(auth.header)
...@@ -273,7 +273,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { ...@@ -273,7 +273,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
extUser.Email = auth.header extUser.Email = auth.header
extUser.Login = auth.header extUser.Login = auth.header
default: default:
return 0, newError("Auth proxy header property invalid", nil) return 0, newError("auth proxy header property invalid", nil)
} }
auth.headersIterator(func(field string, header string) { auth.headersIterator(func(field string, header string) {
......
...@@ -17,31 +17,31 @@ import ( ...@@ -17,31 +17,31 @@ import (
"gopkg.in/macaron.v1" "gopkg.in/macaron.v1"
) )
type TestMultiLDAP struct { type fakeMultiLDAP struct {
multildap.MultiLDAP multildap.MultiLDAP
ID int64 ID int64
userCalled bool userCalled bool
loginCalled bool loginCalled bool
} }
func (stub *TestMultiLDAP) Login(query *models.LoginUserQuery) ( func (m *fakeMultiLDAP) Login(query *models.LoginUserQuery) (
*models.ExternalUserInfo, error, *models.ExternalUserInfo, error,
) { ) {
stub.loginCalled = true m.loginCalled = true
result := &models.ExternalUserInfo{ result := &models.ExternalUserInfo{
UserId: stub.ID, UserId: m.ID,
} }
return result, nil return result, nil
} }
func (stub *TestMultiLDAP) User(login string) ( func (m *fakeMultiLDAP) User(login string) (
*models.ExternalUserInfo, *models.ExternalUserInfo,
ldap.ServerConfig, ldap.ServerConfig,
error, error,
) { ) {
stub.userCalled = true m.userCalled = true
result := &models.ExternalUserInfo{ result := &models.ExternalUserInfo{
UserId: stub.ID, UserId: m.ID,
} }
return result, ldap.ServerConfig{}, nil return result, ldap.ServerConfig{}, nil
} }
...@@ -126,7 +126,7 @@ func TestMiddlewareContext(t *testing.T) { ...@@ -126,7 +126,7 @@ func TestMiddlewareContext(t *testing.T) {
return true return true
} }
stub := &TestMultiLDAP{ stub := &fakeMultiLDAP{
ID: 42, ID: 42,
} }
...@@ -181,7 +181,7 @@ func TestMiddlewareContext(t *testing.T) { ...@@ -181,7 +181,7 @@ func TestMiddlewareContext(t *testing.T) {
auth := prepareMiddleware(t, req, store) auth := prepareMiddleware(t, req, store)
stub := &TestMultiLDAP{ stub := &fakeMultiLDAP{
ID: 42, ID: 42,
} }
......
...@@ -2,7 +2,10 @@ package middleware ...@@ -2,7 +2,10 @@ package middleware
import ( import (
"net/http" "net/http"
"net/url"
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
...@@ -26,14 +29,18 @@ func newCookieOptions() CookieOptions { ...@@ -26,14 +29,18 @@ func newCookieOptions() CookieOptions {
} }
} }
type GetCookieOptionsFunc func() CookieOptions type getCookieOptionsFunc func() CookieOptions
func DeleteCookie(w http.ResponseWriter, name string, getCookieOptionsFunc GetCookieOptionsFunc) { func DeleteCookie(w http.ResponseWriter, name string, getCookieOptions getCookieOptionsFunc) {
WriteCookie(w, name, "", -1, getCookieOptionsFunc) WriteCookie(w, name, "", -1, getCookieOptions)
} }
func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, getCookieOptionsFunc GetCookieOptionsFunc) { func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, getCookieOptions getCookieOptionsFunc) {
options := getCookieOptionsFunc() if getCookieOptions == nil {
getCookieOptions = newCookieOptions
}
options := getCookieOptions()
cookie := http.Cookie{ cookie := http.Cookie{
Name: name, Name: name,
MaxAge: maxAge, MaxAge: maxAge,
...@@ -47,3 +54,18 @@ func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, g ...@@ -47,3 +54,18 @@ func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, g
} }
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }
func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) {
if setting.Env == setting.Dev {
ctx.Logger.Info("New token", "unhashed token", value)
}
var maxAge int
if maxLifetime <= 0 {
maxAge = -1
} else {
maxAge = int(maxLifetime.Seconds())
}
WriteCookie(ctx.Resp, setting.LoginCookieName, url.QueryEscape(value), maxAge, nil)
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -274,21 +273,6 @@ func rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.User ...@@ -274,21 +273,6 @@ func rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.User
} }
} }
func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) {
if setting.Env == setting.Dev {
ctx.Logger.Info("New token", "unhashed token", value)
}
var maxAge int
if maxLifetime <= 0 {
maxAge = -1
} else {
maxAge = int(maxLifetime.Seconds())
}
WriteCookie(ctx.Resp, setting.LoginCookieName, url.QueryEscape(value), maxAge, newCookieOptions)
}
func AddDefaultResponseHeaders() macaron.Handler { func AddDefaultResponseHeaders() macaron.Handler {
return func(ctx *macaron.Context) { return func(ctx *macaron.Context) {
ctx.Resp.Before(func(w macaron.ResponseWriter) { ctx.Resp.Before(func(w macaron.ResponseWriter) {
......
...@@ -21,7 +21,7 @@ import ( ...@@ -21,7 +21,7 @@ import (
"github.com/grafana/grafana/pkg/components/gtime" "github.com/grafana/grafana/pkg/components/gtime"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" "github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
......
...@@ -17,7 +17,7 @@ type Descriptor struct { ...@@ -17,7 +17,7 @@ type Descriptor struct {
var services []*Descriptor var services []*Descriptor
func RegisterServiceWithPriority(instance Service, priority Priority) { func RegisterServiceWithPriority(instance Service, priority Priority) {
services = append(services, &Descriptor{ Register(&Descriptor{
Name: reflect.TypeOf(instance).Elem().Name(), Name: reflect.TypeOf(instance).Elem().Name(),
Instance: instance, Instance: instance,
InitPriority: priority, InitPriority: priority,
...@@ -25,7 +25,7 @@ func RegisterServiceWithPriority(instance Service, priority Priority) { ...@@ -25,7 +25,7 @@ func RegisterServiceWithPriority(instance Service, priority Priority) {
} }
func RegisterService(instance Service) { func RegisterService(instance Service) {
services = append(services, &Descriptor{ Register(&Descriptor{
Name: reflect.TypeOf(instance).Elem().Name(), Name: reflect.TypeOf(instance).Elem().Name(),
Instance: instance, Instance: instance,
InitPriority: Medium, InitPriority: Medium,
...@@ -33,6 +33,17 @@ func RegisterService(instance Service) { ...@@ -33,6 +33,17 @@ func RegisterService(instance Service) {
} }
func Register(descriptor *Descriptor) { func Register(descriptor *Descriptor) {
if descriptor == nil {
return
}
// Overwrite any existing equivalent service
for i, svc := range services {
if svc.Name == descriptor.Name {
services[i] = descriptor
return
}
}
services = append(services, descriptor) services = append(services, descriptor)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment