
package com.ociweb.jmx.jaas.srp;

import org.jboss.security.auth.spi.AbstractServerLoginModule;
import org.jboss.security.SimpleGroup;
import org.jboss.security.SimplePrincipal;

import javax.security.auth.login.LoginException;
import javax.security.auth.Subject;
import javax.security.auth.callback.CallbackHandler;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.sql.DataSource;
import java.security.Principal;
import java.security.acl.Group;
import java.util.*;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * JAAS login module that attaches database-backed roles to an identity that
 * was established elsewhere in the login stack.
 * <p>
 * This module performs no authentication of its own. {@link #login()} reads
 * the user name from the shared-state key
 * {@code javax.security.auth.login.name} (presumably stored there by a
 * preceding login module &mdash; confirm against the login configuration), and
 * {@link #getRoleSets()} loads that user's roles with a SQL query.
 * <p>
 * Module options:
 * <ul>
 *   <li>{@code dsJndiName} &mdash; JNDI name of the {@link DataSource} to query.</li>
 *   <li>{@code rolesQuery} &mdash; SQL with a single {@code ?} parameter, which
 *       is bound to the identity. Column 1 of each row is a role name; column 2
 *       is the name of the group that role belongs to. A null or empty group
 *       name places the role in the {@code "Roles"} group.</li>
 * </ul>
 *
 * @author Brian M. Coyner
 */
public class DatabaseRoleLoginModule extends AbstractServerLoginModule {

    /** JNDI name of the DataSource, from the {@code dsJndiName} option. */
    private String dsJndiName;
    /** Parameterized role query, from the {@code rolesQuery} option. */
    private String rolesQuery;
    /** User name read from shared state in {@link #login()}; may be null. */
    private String identity;
    /** The default {@code "Roles"} group built by the last {@link #getRoleSets()} call. */
    private Group roleGroup;

    /**
     * Initializes the base module and captures this module's options.
     *
     * @param subject         the subject being authenticated
     * @param callbackHandler handler for user interaction (unused here)
     * @param sharedState     state shared with other modules in the stack
     * @param options         module options; {@code dsJndiName} and
     *                        {@code rolesQuery} are read
     */
    public void initialize(Subject subject,
                           CallbackHandler callbackHandler,
                           Map sharedState,
                           Map options) {

        super.initialize(subject, callbackHandler, sharedState, options);
        this.dsJndiName = (String) options.get("dsJndiName");
        this.rolesQuery = (String) options.get("rolesQuery");
    }

    /**
     * Records the identity found in shared state; always returns {@code true}.
     * <p>
     * If the {@code javax.security.auth.login.name} key is absent, the identity
     * stays {@code null} and this method still succeeds.
     * <p>
     * NOTE(review): this override does not call {@code super.login()} or set
     * any base-class login state. In some JBoss versions
     * {@code AbstractServerLoginModule.commit()} is gated on such state.
     * Confirm against the JBoss version in use that the roles are actually
     * committed to the subject.
     *
     * @return {@code true}
     * @throws LoginException never thrown by this implementation
     */
    public boolean login() throws LoginException {
        this.identity = (String) sharedState.get("javax.security.auth.login.name");
        return true;
    }

    /**
     * @return a principal wrapping the identity captured by {@link #login()}
     */
    protected Principal getIdentity() {
        return new SimplePrincipal(this.identity);
    }

    /**
     * Loads the current identity's roles from the configured data source.
     * <p>
     * The result always contains a group named {@code "Roles"}, possibly empty.
     * It also contains one group for every other distinct group name returned
     * in column 2 of the query.
     *
     * @return the role groups for the current identity; never {@code null}
     * @throws LoginException if the data source cannot be looked up or the
     *         query fails. The underlying exception is attached as the cause.
     */
    protected Group[] getRoleSets() throws LoginException {
        InitialContext context = null;
        Connection conn = null;
        PreparedStatement ps = null;
        ResultSet rs = null;
        Map groups = new HashMap();

        try {
            context = new InitialContext();
            DataSource ds = (DataSource) context.lookup(this.dsJndiName);
            conn = ds.getConnection();

            ps = conn.prepareStatement(this.rolesQuery);
            ps.setString(1, this.identity);
            rs = ps.executeQuery();

            // The default group is always returned, even when the user has no roles.
            this.roleGroup = new SimpleGroup("Roles");
            groups.put("Roles", this.roleGroup);

            while (rs.next()) {
                String name = rs.getString(1);
                String groupName = rs.getString(2);
                if (groupName == null || groupName.length() == 0) {
                    groupName = "Roles";
                }

                Group group = (Group) groups.get(groupName);
                if (group == null) {
                    group = new SimpleGroup(groupName);
                    groups.put(groupName, group);
                }

                group.addMember(new SimplePrincipal(name));
            }
        } catch (NamingException ex) {
            throw loginFailure(ex.toString(true), ex);
        } catch (SQLException ex) {
            throw loginFailure(ex.toString(), ex);
        } finally {
            // Any of these may still be null if an earlier step failed.
            releaseQuietly(rs, ps, conn, context);
        }

        Group[] roleSets = new Group[groups.size()];
        groups.values().toArray(roleSets);
        return roleSets;
    }

    /** Builds a LoginException with the given message and the original exception as its cause. */
    private static LoginException loginFailure(String message, Throwable cause) {
        LoginException failure = new LoginException(message);
        failure.initCause(cause);
        return failure;
    }

    /**
     * Closes each non-null resource in reverse acquisition order.
     * Close failures are ignored deliberately: cleanup is best-effort and
     * must not mask the primary result or exception of the caller.
     */
    private static void releaseQuietly(ResultSet rs,
                                       PreparedStatement ps,
                                       Connection conn,
                                       InitialContext context) {
        if (rs != null) {
            try {
                rs.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
        if (ps != null) {
            try {
                ps.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
        if (conn != null) {
            try {
                conn.close();
            } catch (SQLException ignored) {
                // best-effort cleanup
            }
        }
        if (context != null) {
            try {
                context.close();
            } catch (NamingException ignored) {
                // best-effort cleanup
            }
        }
    }
}