Commit 99c799e9 by gotjosh Committed by GitHub

Close the connection only if we establish it. (#18897)

parent e9f1e86c
...@@ -46,6 +46,9 @@ type Server struct { ...@@ -46,6 +46,9 @@ type Server struct {
// Bind authenticates the connection with the LDAP server // Bind authenticates the connection with the LDAP server
// - with the username and password setup in the config // - with the username and password setup in the config
// - or, anonymously // - or, anonymously
//
// Dial() sets the connection with the server for this Struct. Therefore, we require a
// call to Dial() before being able to execute this function.
func (server *Server) Bind() error { func (server *Server) Bind() error {
if server.shouldAdminBind() { if server.shouldAdminBind() {
if err := server.AdminBind(); err != nil { if err := server.AdminBind(); err != nil {
...@@ -139,6 +142,8 @@ func (server *Server) Dial() error { ...@@ -139,6 +142,8 @@ func (server *Server) Dial() error {
} }
// Close closes the LDAP connection // Close closes the LDAP connection
// Dial() sets the connection with the server for this Struct. Therefore, we require a
// call to Dial() before being able to execute this function.
func (server *Server) Close() { func (server *Server) Close() {
server.Connection.Close() server.Connection.Close()
} }
...@@ -158,6 +163,9 @@ func (server *Server) Close() { ...@@ -158,6 +163,9 @@ func (server *Server) Close() {
// user without login/password binding with LDAP server, in such case // user without login/password binding with LDAP server, in such case
// we will perform "unauthenticated bind", then search for the // we will perform "unauthenticated bind", then search for the
// targeted user and then perform the bind with passed login/password. // targeted user and then perform the bind with passed login/password.
//
// Dial() sets the connection with the server for this Struct. Therefore, we require a
// call to Dial() before being able to execute this function.
func (server *Server) Login(query *models.LoginUserQuery) ( func (server *Server) Login(query *models.LoginUserQuery) (
*models.ExternalUserInfo, error, *models.ExternalUserInfo, error,
) { ) {
...@@ -231,6 +239,8 @@ func (server *Server) shouldSingleBind() bool { ...@@ -231,6 +239,8 @@ func (server *Server) shouldSingleBind() bool {
} }
// Users gets LDAP users by logins // Users gets LDAP users by logins
// Dial() sets the connection with the server for this Struct. Therefore, we require a
// call to Dial() before being able to execute this function.
func (server *Server) Users(logins []string) ( func (server *Server) Users(logins []string) (
[]*models.ExternalUserInfo, []*models.ExternalUserInfo,
error, error,
...@@ -414,6 +424,8 @@ func (server *Server) buildGrafanaUser(user *ldap.Entry) (*models.ExternalUserIn ...@@ -414,6 +424,8 @@ func (server *Server) buildGrafanaUser(user *ldap.Entry) (*models.ExternalUserIn
} }
// UserBind binds the user with the LDAP server // UserBind binds the user with the LDAP server
// Dial() sets the connection with the server for this Struct. Therefore, we require a
// call to Dial() before being able to execute this function.
func (server *Server) UserBind(username, password string) error { func (server *Server) UserBind(username, password string) error {
err := server.userBind(username, password) err := server.userBind(username, password)
if err != nil { if err != nil {
...@@ -429,6 +441,8 @@ func (server *Server) UserBind(username, password string) error { ...@@ -429,6 +441,8 @@ func (server *Server) UserBind(username, password string) error {
} }
// AdminBind binds "admin" user with LDAP // AdminBind binds "admin" user with LDAP
// Dial() sets the connection with the server for this Struct. Therefore, we require a
// call to Dial() before being able to execute this function.
func (server *Server) AdminBind() error { func (server *Server) AdminBind() error {
err := server.userBind(server.Config.BindDN, server.Config.BindPassword) err := server.userBind(server.Config.BindDN, server.Config.BindPassword)
if err != nil { if err != nil {
......
...@@ -4,10 +4,9 @@ import ( ...@@ -4,10 +4,9 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/grafana/grafana/pkg/infra/log"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
"gopkg.in/ldap.v3" "gopkg.in/ldap.v3"
"github.com/grafana/grafana/pkg/infra/log"
) )
func TestPublicAPI(t *testing.T) { func TestPublicAPI(t *testing.T) {
...@@ -22,6 +21,34 @@ func TestPublicAPI(t *testing.T) { ...@@ -22,6 +21,34 @@ func TestPublicAPI(t *testing.T) {
}) })
}) })
Convey("Close()", t, func() {
Convey("Should close the connection", func() {
connection := &MockConnection{}
server := &Server{
Config: &ServerConfig{
Attr: AttributeMap{},
SearchBaseDNs: []string{"BaseDNHere"},
},
Connection: connection,
}
So(server.Close, ShouldNotPanic)
So(connection.CloseCalled, ShouldBeTrue)
})
Convey("Should panic if no connection is established", func() {
server := &Server{
Config: &ServerConfig{
Attr: AttributeMap{},
SearchBaseDNs: []string{"BaseDNHere"},
},
Connection: nil,
}
So(server.Close, ShouldPanic)
})
})
Convey("Users()", t, func() { Convey("Users()", t, func() {
Convey("Finds one user", func() { Convey("Finds one user", func() {
MockConnection := &MockConnection{} MockConnection := &MockConnection{}
......
...@@ -19,6 +19,8 @@ type MockConnection struct { ...@@ -19,6 +19,8 @@ type MockConnection struct {
DelParams *ldap.DelRequest DelParams *ldap.DelRequest
DelCalled bool DelCalled bool
CloseCalled bool
UnauthenticatedBindCalled bool UnauthenticatedBindCalled bool
BindCalled bool BindCalled bool
...@@ -49,7 +51,9 @@ func (c *MockConnection) UnauthenticatedBind(username string) error { ...@@ -49,7 +51,9 @@ func (c *MockConnection) UnauthenticatedBind(username string) error {
} }
// Close mocks Close connection function // Close mocks Close connection function
func (c *MockConnection) Close() {} func (c *MockConnection) Close() {
c.CloseCalled = true
}
func (c *MockConnection) setSearchResult(result *ldap.SearchResult) { func (c *MockConnection) setSearchResult(result *ldap.SearchResult) {
c.SearchResult = result c.SearchResult = result
......
...@@ -85,13 +85,12 @@ func (multiples *MultiLDAP) Ping() ([]*ServerStatus, error) { ...@@ -85,13 +85,12 @@ func (multiples *MultiLDAP) Ping() ([]*ServerStatus, error) {
if err == nil { if err == nil {
status.Available = true status.Available = true
serverStatuses = append(serverStatuses, status) serverStatuses = append(serverStatuses, status)
server.Close()
} else { } else {
status.Available = false status.Available = false
status.Error = err status.Error = err
serverStatuses = append(serverStatuses, status) serverStatuses = append(serverStatuses, status)
} }
defer server.Close()
} }
return serverStatuses, nil return serverStatuses, nil
......
...@@ -40,11 +40,12 @@ func TestMultiLDAP(t *testing.T) { ...@@ -40,11 +40,12 @@ func TestMultiLDAP(t *testing.T) {
So(statuses[0].Port, ShouldEqual, 361) So(statuses[0].Port, ShouldEqual, 361)
So(statuses[0].Available, ShouldBeFalse) So(statuses[0].Available, ShouldBeFalse)
So(statuses[0].Error, ShouldEqual, expectedErr) So(statuses[0].Error, ShouldEqual, expectedErr)
So(mock.closeCalledTimes, ShouldEqual, 0)
teardown() teardown()
}) })
Convey("Shoudl get the LDAP server statuses", func() { Convey("Should get the LDAP server statuses", func() {
setup() mock := setup()
multi := New([]*ldap.ServerConfig{ multi := New([]*ldap.ServerConfig{
{Host: "10.0.0.1", Port: 361}, {Host: "10.0.0.1", Port: 361},
...@@ -57,6 +58,7 @@ func TestMultiLDAP(t *testing.T) { ...@@ -57,6 +58,7 @@ func TestMultiLDAP(t *testing.T) {
So(statuses[0].Port, ShouldEqual, 361) So(statuses[0].Port, ShouldEqual, 361)
So(statuses[0].Available, ShouldBeTrue) So(statuses[0].Available, ShouldBeTrue)
So(statuses[0].Error, ShouldBeNil) So(statuses[0].Error, ShouldBeNil)
So(mock.closeCalledTimes, ShouldEqual, 1)
teardown() teardown()
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment