Commit 318182cc by Oleg Gaidarenko Committed by Torkel Ödegaard

Chore: refactor auth proxy (#16504)

* Chore: refactor auth proxy

Introduced the helper struct for auth_proxy middleware.
Added couple unit-tests, but it seems "integration" tests already cover
most of the code paths.

Although it might be good idea to test every bit of it, hm.
Haven't refactored the extraction of the header logic that much

Fixes #16147

* Fix: make linters happy
parent 8069a617
package middleware
import (
"fmt"
"net"
"net/mail"
"reflect"
"strings"
"time"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login"
authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
const (
// cachePrefix is a prefix for the cache key
cachePrefix = "auth-proxy-sync-ttl:%s"
cachePrefix = authproxy.CachePrefix
)
func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *m.ReqContext, orgID int64) bool {
if !setting.AuthProxyEnabled {
auth := authproxy.New(&authproxy.Options{
Store: store,
Ctx: ctx,
OrgID: orgID,
})
// Bail if auth proxy is not enabled
if auth.IsEnabled() == false {
return false
}
proxyHeaderValue := ctx.Req.Header.Get(setting.AuthProxyHeaderName)
if len(proxyHeaderValue) == 0 {
// If the there is no header - we can't move forward
if auth.HasHeader() == false {
return false
}
// if auth proxy ip(s) defined, check if request comes from one of those
if err := checkAuthenticationProxy(ctx.Req.RemoteAddr, proxyHeaderValue); err != nil {
ctx.Handle(407, "Proxy authentication required", err)
// Check if allowed to continue with this IP
if result, err := auth.IsAllowedIP(); result == false {
ctx.Handle(407, err.Error(), err.DetailsError)
return true
}
query := &m.GetSignedInUserQuery{OrgId: orgID}
cacheKey := fmt.Sprintf(cachePrefix, proxyHeaderValue)
userID, err := store.Get(cacheKey)
inCache := err == nil
// load the user if we have them
if inCache {
query.UserId = userID.(int64)
// if we're using ldap, pass authproxy login name to ldap user sync
} else if setting.LdapEnabled {
syncQuery := &m.LoginUserQuery{
ReqContext: ctx,
Username: proxyHeaderValue,
}
if err := syncGrafanaUserWithLdapUser(syncQuery); err != nil {
if err == login.ErrInvalidCredentials {
ctx.Handle(500, "Unable to authenticate user", err)
return false
}
}
if syncQuery.User == nil {
ctx.Handle(500, "Failed to sync user", nil)
return false
}
query.UserId = syncQuery.User.Id
// no ldap, just use the info we have
} else {
extUser := &m.ExternalUserInfo{
AuthModule: "authproxy",
AuthId: proxyHeaderValue,
}
if setting.AuthProxyHeaderProperty == "username" {
extUser.Login = proxyHeaderValue
// only set Email if it can be parsed as an email address
emailAddr, emailErr := mail.ParseAddress(proxyHeaderValue)
if emailErr == nil {
extUser.Email = emailAddr.Address
}
} else if setting.AuthProxyHeaderProperty == "email" {
extUser.Email = proxyHeaderValue
extUser.Login = proxyHeaderValue
} else {
ctx.Handle(500, "Auth proxy header property invalid", nil)
return true
}
for _, field := range []string{"Name", "Email", "Login"} {
if setting.AuthProxyHeaders[field] == "" {
continue
}
if val := ctx.Req.Header.Get(setting.AuthProxyHeaders[field]); val != "" {
reflect.ValueOf(extUser).Elem().FieldByName(field).SetString(val)
}
}
// add/update user in grafana
cmd := &m.UpsertUserCommand{
ReqContext: ctx,
ExternalUser: extUser,
SignupAllowed: setting.AuthProxyAutoSignUp,
}
err := bus.Dispatch(cmd)
if err != nil {
ctx.Handle(500, "Failed to login as user specified in auth proxy header", err)
return true
}
query.UserId = cmd.Result.Id
// Try to get user id from various sources
id, err := auth.GetUserID()
if err != nil {
ctx.Handle(500, err.Error(), err.DetailsError)
return true
}
if err := bus.Dispatch(query); err != nil {
ctx.Handle(500, "Failed to find user", err)
// Get full user info
user, err := auth.GetSignedUser(id)
if err != nil {
ctx.Handle(500, err.Error(), err.DetailsError)
return true
}
ctx.SignedInUser = query.Result
ctx.IsSignedIn = true
expiration := time.Duration(-setting.AuthProxyLdapSyncTtl) * time.Minute
value := query.UserId
// Add user info to context
ctx.SignedInUser = user
ctx.IsSignedIn = true
// This <if> is here to make sure we do not
// rewrite the expiration all the time
if inCache == false {
if err = store.Set(cacheKey, value, expiration); err != nil {
ctx.Handle(500, "Couldn't write a user in cache key", err)
return true
}
// Remember user data it in cache
if err := auth.Remember(); err != nil {
ctx.Handle(500, err.Error(), err.DetailsError)
return true
}
return true
}
var syncGrafanaUserWithLdapUser = func(query *m.LoginUserQuery) error {
ldapCfg := login.LdapCfg
if len(ldapCfg.Servers) < 1 {
return fmt.Errorf("No LDAP servers available")
}
for _, server := range ldapCfg.Servers {
author := login.NewLdapAuthenticator(server)
if err := author.SyncUser(query); err != nil {
return err
}
}
return nil
}
func checkAuthenticationProxy(remoteAddr string, proxyHeaderValue string) error {
if len(strings.TrimSpace(setting.AuthProxyWhitelist)) == 0 {
return nil
}
proxies := strings.Split(setting.AuthProxyWhitelist, ",")
var proxyObjs []*net.IPNet
for _, proxy := range proxies {
proxyObjs = append(proxyObjs, coerceProxyAddress(proxy))
}
sourceIP, _, _ := net.SplitHostPort(remoteAddr)
sourceObj := net.ParseIP(sourceIP)
for _, proxyObj := range proxyObjs {
if proxyObj.Contains(sourceObj) {
return nil
}
}
return fmt.Errorf("Request for user (%s) from %s is not from the authentication proxy", proxyHeaderValue, sourceIP)
}
func coerceProxyAddress(proxyAddr string) *net.IPNet {
proxyAddr = strings.TrimSpace(proxyAddr)
if !strings.Contains(proxyAddr, "/") {
proxyAddr = strings.Join([]string{proxyAddr, "32"}, "/")
}
_, network, err := net.ParseCIDR(proxyAddr)
if err != nil {
fmt.Println(err)
}
return network
}
package authproxy
import (
"fmt"
"net"
"net/mail"
"reflect"
"strings"
"time"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login"
models "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
const (
// CachePrefix is a prefix for the cache key
CachePrefix = "auth-proxy-sync-ttl:%s"
)
// AuthProxy struct
type AuthProxy struct {
store *remotecache.RemoteCache
ctx *models.ReqContext
orgID int64
header string
LDAP func(server *login.LdapServerConf) login.ILdapAuther
enabled bool
whitelistIP string
headerType string
headers map[string]string
cacheTTL int
ldapEnabled bool
}
// Error auth proxy specific error
type Error struct {
Message string
DetailsError error
}
// newError creates the Error
func newError(message string, err error) *Error {
return &Error{
Message: message,
DetailsError: err,
}
}
// Error returns a Error error string
func (err *Error) Error() string {
return fmt.Sprintf("%s", err.Message)
}
// Options for the AuthProxy
type Options struct {
Store *remotecache.RemoteCache
Ctx *models.ReqContext
OrgID int64
}
// New instance of the AuthProxy
func New(options *Options) *AuthProxy {
header := options.Ctx.Req.Header.Get(setting.AuthProxyHeaderName)
return &AuthProxy{
store: options.Store,
ctx: options.Ctx,
orgID: options.OrgID,
header: header,
LDAP: login.NewLdapAuthenticator,
enabled: setting.AuthProxyEnabled,
headerType: setting.AuthProxyHeaderProperty,
headers: setting.AuthProxyHeaders,
whitelistIP: setting.AuthProxyWhitelist,
cacheTTL: setting.AuthProxyLdapSyncTtl,
ldapEnabled: setting.LdapEnabled,
}
}
// IsEnabled checks if the proxy auth is enabled
func (auth *AuthProxy) IsEnabled() bool {
// Bail if the setting is not enabled
if auth.enabled == false {
return false
}
return true
}
// HasHeader checks if the we have specified header
func (auth *AuthProxy) HasHeader() bool {
if len(auth.header) == 0 {
return false
}
return true
}
// IsAllowedIP compares presented IP with the whitelist one
func (auth *AuthProxy) IsAllowedIP() (bool, *Error) {
ip := auth.ctx.Req.RemoteAddr
if len(strings.TrimSpace(auth.whitelistIP)) == 0 {
return true, nil
}
proxies := strings.Split(auth.whitelistIP, ",")
var proxyObjs []*net.IPNet
for _, proxy := range proxies {
result, err := coerceProxyAddress(proxy)
if err != nil {
return false, newError("Could not get the network", err)
}
proxyObjs = append(proxyObjs, result)
}
sourceIP, _, _ := net.SplitHostPort(ip)
sourceObj := net.ParseIP(sourceIP)
for _, proxyObj := range proxyObjs {
if proxyObj.Contains(sourceObj) {
return true, nil
}
}
err := fmt.Errorf(
"Request for user (%s) from %s is not from the authentication proxy", auth.header,
sourceIP,
)
return false, newError("Proxy authentication required", err)
}
// InCache checks if we have user in cache
func (auth *AuthProxy) InCache() bool {
userID, _ := auth.GetUserIDViaCache()
if userID == 0 {
return false
}
return true
}
// getKey forms a key for the cache
func (auth *AuthProxy) getKey() string {
return fmt.Sprintf(CachePrefix, auth.header)
}
// GetUserID gets user id with whatever means possible
func (auth *AuthProxy) GetUserID() (int64, *Error) {
if auth.InCache() {
// Error here means absent cache - we don't need to handle that
id, _ := auth.GetUserIDViaCache()
return id, nil
}
if auth.ldapEnabled {
id, err := auth.GetUserIDViaLDAP()
if err == login.ErrInvalidCredentials {
return 0, newError("Proxy authentication required", login.ErrInvalidCredentials)
}
if err != nil {
return 0, newError("Failed to sync user", err)
}
return id, nil
}
id, err := auth.GetUserIDViaHeader()
if err != nil {
return 0, newError("Failed to login as user specified in auth proxy header", err)
}
return id, nil
}
// GetUserIDViaCache gets the user from cache
func (auth *AuthProxy) GetUserIDViaCache() (int64, error) {
var (
cacheKey = auth.getKey()
userID, err = auth.store.Get(cacheKey)
)
if err != nil {
return 0, err
}
return userID.(int64), nil
}
// GetUserIDViaLDAP gets user via LDAP request
func (auth *AuthProxy) GetUserIDViaLDAP() (int64, *Error) {
query := &models.LoginUserQuery{
ReqContext: auth.ctx,
Username: auth.header,
}
ldapCfg := login.LdapCfg
if len(ldapCfg.Servers) < 1 {
return 0, newError("No LDAP servers available", nil)
}
for _, server := range ldapCfg.Servers {
author := auth.LDAP(server)
if err := author.SyncUser(query); err != nil {
return 0, newError(err.Error(), nil)
}
}
return query.User.Id, nil
}
// GetUserIDViaHeader gets user from the header only
func (auth *AuthProxy) GetUserIDViaHeader() (int64, error) {
extUser := &models.ExternalUserInfo{
AuthModule: "authproxy",
AuthId: auth.header,
}
if auth.headerType == "username" {
extUser.Login = auth.header
// only set Email if it can be parsed as an email address
emailAddr, emailErr := mail.ParseAddress(auth.header)
if emailErr == nil {
extUser.Email = emailAddr.Address
}
} else if auth.headerType == "email" {
extUser.Email = auth.header
extUser.Login = auth.header
} else {
return 0, newError("Auth proxy header property invalid", nil)
}
for _, field := range []string{"Name", "Email", "Login"} {
if auth.headers[field] == "" {
continue
}
if val := auth.ctx.Req.Header.Get(auth.headers[field]); val != "" {
reflect.ValueOf(extUser).Elem().FieldByName(field).SetString(val)
}
}
// add/update user in grafana
cmd := &models.UpsertUserCommand{
ReqContext: auth.ctx,
ExternalUser: extUser,
SignupAllowed: setting.AuthProxyAutoSignUp,
}
err := bus.Dispatch(cmd)
if err != nil {
return 0, err
}
return cmd.Result.Id, nil
}
// GetSignedUser get full signed user info
func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, *Error) {
query := &models.GetSignedInUserQuery{
OrgId: auth.orgID,
UserId: userID,
}
if err := bus.Dispatch(query); err != nil {
return nil, newError(err.Error(), nil)
}
return query.Result, nil
}
// Remember user in cache
func (auth *AuthProxy) Remember() *Error {
// Make sure we do not rewrite the expiration time
if auth.InCache() {
return nil
}
var (
key = auth.getKey()
value, _ = auth.GetUserIDViaCache()
expiration = time.Duration(-auth.cacheTTL) * time.Minute
err = auth.store.Set(key, value, expiration)
)
if err != nil {
return newError(err.Error(), nil)
}
return nil
}
// coerceProxyAddress gets network of the presented CIDR notation
func coerceProxyAddress(proxyAddr string) (*net.IPNet, error) {
proxyAddr = strings.TrimSpace(proxyAddr)
if !strings.Contains(proxyAddr, "/") {
proxyAddr = strings.Join([]string{proxyAddr, "32"}, "/")
}
_, network, err := net.ParseCIDR(proxyAddr)
return network, err
}
package authproxy
import (
"fmt"
"net/http"
"testing"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login"
models "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey"
"gopkg.in/macaron.v1"
)
type TestLDAP struct {
login.ILdapAuther
ID int64
syncCalled bool
}
func (stub *TestLDAP) SyncUser(query *models.LoginUserQuery) error {
stub.syncCalled = true
query.User = &models.User{
Id: stub.ID,
}
return nil
}
func TestMiddlewareContext(t *testing.T) {
Convey("auth_proxy helper", t, func() {
req, _ := http.NewRequest("POST", "http://example.com", nil)
setting.AuthProxyHeaderName = "X-Killa"
name := "markelog"
req.Header.Add(setting.AuthProxyHeaderName, name)
ctx := &models.ReqContext{
Context: &macaron.Context{
Req: macaron.Request{
Request: req,
},
},
}
Convey("gets data from the cache", func() {
store := remotecache.NewFakeStore(t)
key := fmt.Sprintf(CachePrefix, name)
store.Set(key, int64(33), 0)
auth := New(&Options{
Store: store,
Ctx: ctx,
OrgID: 4,
})
id, err := auth.GetUserID()
So(err, ShouldBeNil)
So(id, ShouldEqual, 33)
})
Convey("LDAP", func() {
Convey("gets data from the LDAP", func() {
login.LdapCfg = login.LdapConfig{
Servers: []*login.LdapServerConf{
{},
},
}
setting.LdapEnabled = true
store := remotecache.NewFakeStore(t)
auth := New(&Options{
Store: store,
Ctx: ctx,
OrgID: 4,
})
stub := &TestLDAP{
ID: 42,
}
auth.LDAP = func(server *login.LdapServerConf) login.ILdapAuther {
return stub
}
id, err := auth.GetUserID()
So(err, ShouldBeNil)
So(id, ShouldEqual, 42)
So(stub.syncCalled, ShouldEqual, true)
})
Convey("gets nice error if ldap is enabled but not configured", func() {
setting.LdapEnabled = false
store := remotecache.NewFakeStore(t)
auth := New(&Options{
Store: store,
Ctx: ctx,
OrgID: 4,
})
stub := &TestLDAP{
ID: 42,
}
auth.LDAP = func(server *login.LdapServerConf) login.ILdapAuther {
return stub
}
id, err := auth.GetUserID()
So(err, ShouldNotBeNil)
So(id, ShouldNotEqual, 42)
So(stub.syncCalled, ShouldEqual, false)
})
})
})
}
......@@ -276,52 +276,9 @@ func TestMiddlewareContext(t *testing.T) {
setting.AuthProxyHeaderProperty = "username"
name := "markelog"
middlewareScenario(t, "should sync the user if it's not in the cache", func(sc *scenarioContext) {
called := false
syncGrafanaUserWithLdapUser = func(query *m.LoginUserQuery) error {
called = true
query.User = &m.User{Id: 32}
return nil
}
bus.AddHandler("test", func(query *m.UpsertUserCommand) error {
query.Result = &m.User{Id: 32}
return nil
})
bus.AddHandler("test", func(query *m.GetSignedInUserQuery) error {
query.Result = &m.SignedInUser{OrgId: 4, UserId: 32}
return nil
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.exec()
Convey("Should init user via ldap", func() {
So(called, ShouldBeTrue)
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 32)
So(sc.context.OrgId, ShouldEqual, 4)
})
})
middlewareScenario(t, "should not sync the user if it's in the cache", func(sc *scenarioContext) {
called := false
syncGrafanaUserWithLdapUser = func(query *m.LoginUserQuery) error {
called = true
query.User = &m.User{Id: 32}
return nil
}
bus.AddHandler("test", func(query *m.UpsertUserCommand) error {
query.Result = &m.User{Id: 32}
return nil
})
bus.AddHandler("test", func(query *m.GetSignedInUserQuery) error {
query.Result = &m.SignedInUser{OrgId: 4, UserId: 32}
query.Result = &m.SignedInUser{OrgId: 4, UserId: query.UserId}
return nil
})
......@@ -332,17 +289,10 @@ func TestMiddlewareContext(t *testing.T) {
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.exec()
cacheValue, cacheErr := sc.remoteCacheService.Get(key)
Convey("Should init user via cache", func() {
So(called, ShouldBeFalse)
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 32)
So(sc.context.UserId, ShouldEqual, 33)
So(sc.context.OrgId, ShouldEqual, 4)
So(cacheValue, ShouldEqual, 33)
So(cacheErr, ShouldBeNil)
})
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment