package com.house365.web.filter;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.house365.commons.system.FileUtils;

/**
 * XSS attack filter.<br>
 * Intercepts every request and, unless its path is on the configured whitelist, passes it down the
 * chain wrapped in an {@code XssHttpServletRequestWrapper} configured with the replacement rules
 * below, with the aim of neutralizing XSS payloads by character/string substitution.<br>
 *
 * <pre>
 * <b>
 * Usage:
 * 1. Add the dependency to pom.xml
 * +---------------------------------------------------+
 *  &lt;dependency&gt;
 *      &lt;groupId&gt;com.house365&lt;/groupId&gt;
 *      &lt;artifactId&gt;house365-web-filter&lt;/artifactId&gt;
 *      &lt;version&gt;2.0.0-SNAPSHOT&lt;/version&gt;
 *  &lt;/dependency&gt;
 * +---------------------------------------------------+
 * 2. Register the filter in the web application's web.xml
 * +---------------------------------------------------+
 *  &lt;filter&gt;
 *      &lt;description&gt;XSS attack filter&lt;/description&gt;
 *      &lt;display-name&gt;XSSFilter&lt;/display-name&gt;
 *      &lt;filter-name&gt;XSSFilter&lt;/filter-name&gt;
 *      &lt;filter-class&gt;com.house365.web.filter.XSSFilter&lt;/filter-class&gt;
 *      &lt;init-param&gt;
 *          &lt;param-name&gt;unFilterUrlsFile&lt;/param-name&gt;
 *          &lt;param-value&gt;properties/unFilterUrls&lt;/param-value&gt;
 *      &lt;/init-param&gt;
 *      &lt;init-param&gt;
 *          &lt;param-name&gt;changeSpecialCharacterFile&lt;/param-name&gt;
 *          &lt;param-value&gt;properties/escapeCharacter.properties&lt;/param-value&gt;
 *      &lt;/init-param&gt;
 *      &lt;init-param&gt;
 *          &lt;param-name&gt;trimToNullFlag&lt;/param-name&gt;
 *          &lt;param-value&gt;true&lt;/param-value&gt;
 *      &lt;/init-param&gt;
 *  &lt;/filter&gt;
 *  &lt;filter-mapping&gt;
 *      &lt;filter-name&gt;XSSFilter&lt;/filter-name&gt;
 *      &lt;url-pattern&gt;/*&lt;/url-pattern&gt;
 *  &lt;/filter-mapping&gt;
 * +---------------------------------------------------+
 *    All init-params are optional; both files are loaded as classpath resources and
 *    trimToNullFlag defaults to true.
 * 3. Place the referenced files at the configured classpath locations
 * 3.1. properties/unFilterUrls (requests that are NOT filtered): one path per line, relative to
 *      the context path and matched exactly; blank lines are ignored. Example:
 *      ----------------------------------------------
 *      /cas
 *      /login
 *      /logout
 *      ----------------------------------------------
 * 3.2. properties/escapeCharacter.properties (replacement rules for special characters/strings):
 *      one "search = replacement" rule per line, split at the first '=' with both sides trimmed;
 *      lines without '=' or with an empty search string are ignored. Example:
 *      ----------------------------------------------
 *      &amp;amp; = &amp;amp;amp;
 *      &amp;nbsp = &amp;amp;nbsp
 *      &amp;nbsp; = &amp;amp;nbsp;
 *      &amp;quot; = &amp;amp;quot;
 *      &lt; = &amp;lt;
 *      &gt; = &amp;gt;
 *      ----------------------------------------------
 * +---------------------------------------------------+
 * </b>
 * </pre>
 *
 * @author duhui
 * @version 2.0.0, 2015-01-12
 */
public class XSSFilter implements Filter {
    /**
     * Logger.
     */
    private static final Logger LOGGER = LoggerFactory.getLogger(XSSFilter.class);

    /**
     * Request paths (relative to the context path) that bypass XSS filtering. Stays {@code null}
     * when no whitelist is configured or it could not be loaded, in which case every request is
     * filtered.
     */
    private Set<String> unFilterUrls = null;

    /**
     * Flag forwarded to {@code XssHttpServletRequestWrapper}; set from the {@code trimToNullFlag}
     * init-param, default {@code true}. Originally documented as "whether to replace special
     * characters". NOTE(review): confirm the exact semantics against the wrapper.
     */
    private boolean trimToNull = true;

    /**
     * Replacement rules (search string -> replacement) in the order they appear in the rules file.
     * Replaced by a read-only map once the file has been loaded successfully.
     */
    private Map<String, String> escapeCharacters = new LinkedHashMap<String, String>();

    /**
     * Default constructor.
     */
    public XSSFilter() {
    }

    /**
     * No resources to release.
     */
    public void destroy() {
    }

    /**
     * Stores the incoming request under the request attribute {@code "requestself"}. Whitelisted
     * requests are then passed on unchanged. All others are passed on wrapped in an
     * {@code XssHttpServletRequestWrapper}.
     *
     * @param request the incoming request; must be an {@link HttpServletRequest}
     * @param response the response, passed through unchanged
     * @param chain the remaining filter chain
     * @throws IOException if thrown further down the chain
     * @throws ServletException if thrown further down the chain
     */
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException,
            ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // Presumably lets downstream code reach the unwrapped request despite the wrapper below.
        httpRequest.setAttribute("requestself", httpRequest);
        if (unfilter(httpRequest)) {
            chain.doFilter(request, response);
        } else {
            chain.doFilter(new XssHttpServletRequestWrapper(request, this.trimToNull, this.escapeCharacters), response);
        }
    }

    /**
     * Reads the optional init-params {@code trimToNullFlag}, {@code changeSpecialCharacterFile} and
     * {@code unFilterUrlsFile}. A file that cannot be loaded is logged and skipped, so the filter
     * still starts. In that case it has no replacement rules, or no whitelist, respectively.
     *
     * @param filterConfig filter configuration
     * @throws ServletException declared by {@link Filter#init}; not thrown by this implementation
     */
    public void init(FilterConfig filterConfig) throws ServletException {
        String trimToNullFlag = filterConfig.getInitParameter("trimToNullFlag");
        if (StringUtils.isNotBlank(trimToNullFlag)) {
            trimToNull = Boolean.parseBoolean(trimToNullFlag.trim());
        }

        String changeSpecialCharacterFile = filterConfig.getInitParameter("changeSpecialCharacterFile");
        if (StringUtils.isNotBlank(changeSpecialCharacterFile)) {
            loadEscapeCharacters(changeSpecialCharacterFile.trim());
        }

        String unFilterUrlsFile = filterConfig.getInitParameter("unFilterUrlsFile");
        if (StringUtils.isNotBlank(unFilterUrlsFile)) {
            loadUnFilterUrls(unFilterUrlsFile.trim());
        }
    }

    /**
     * Loads the "search = replacement" rules from a classpath resource into {@link #escapeCharacters}.
     * On failure the error is logged and the previous (empty) rule set is kept.
     *
     * @param resourceName classpath resource holding one rule per line
     */
    private void loadEscapeCharacters(String resourceName) {
        try {
            List<String> lines = FileUtils.readFileByLine(resolveResourcePath(resourceName));
            Map<String, String> rules = new LinkedHashMap<String, String>();
            for (String line : lines) {
                int index = line.indexOf('=');
                if (index == -1) {
                    continue;
                }
                String key = line.substring(0, index).trim();
                String value = line.substring(index + 1).trim();
                if (key.length() == 0) {
                    // An empty search string is never a meaningful rule and, if used for
                    // substring replacement, would match at every position.
                    LOGGER.warn("Ignoring escape rule with an empty search string: [{}]", line);
                    continue;
                }
                rules.put(key, value);
            }
            // Publish the complete rule set at once, read-only.
            escapeCharacters = Collections.unmodifiableMap(rules);
        } catch (Exception e) {
            LOGGER.error("无法加载特殊字符转化列表", e);
        }
    }

    /**
     * Loads the whitelist of unfiltered request paths from a classpath resource into
     * {@link #unFilterUrls}. On failure the error is logged and no whitelist is installed.
     *
     * @param resourceName classpath resource holding one path per line
     */
    private void loadUnFilterUrls(String resourceName) {
        try {
            List<String> lines = FileUtils.readFileByLine(resolveResourcePath(resourceName));
            Set<String> urls = new HashSet<String>();
            for (String line : lines) {
                String url = StringUtils.trimToNull(line);
                // Skip blank lines: an empty entry would whitelist requests for the bare context root,
                // whose context-relative path is "".
                if (url != null) {
                    urls.add(url);
                }
            }
            unFilterUrls = Collections.unmodifiableSet(urls);
        } catch (Exception e) {
            LOGGER.error("无法加载不过滤请求白名单列表", e);
        }
    }

    /**
     * Resolves a classpath resource to a path string for {@code FileUtils.readFileByLine}.
     * {@code URL.getFile()} keeps percent-escapes such as {@code %20}, so for {@code file:} URLs the
     * path is decoded via {@link URL#toURI()} and {@link File#File(java.net.URI)}. Other URL kinds
     * (e.g. resources inside a JAR) fall back to the raw URL path.
     *
     * @param resourceName classpath resource name
     * @return path to pass to {@code FileUtils.readFileByLine}
     * @throws FileNotFoundException if the resource is not on the classpath
     */
    private static String resolveResourcePath(String resourceName) throws FileNotFoundException {
        URL url = XSSFilter.class.getClassLoader().getResource(resourceName);
        if (url == null) {
            throw new FileNotFoundException("Classpath resource not found: " + resourceName);
        }
        if ("file".equals(url.getProtocol())) {
            try {
                return new File(url.toURI()).getPath();
            } catch (URISyntaxException e) {
                // The class loader may hand out unescaped URLs; their raw path is then usable as-is.
                LOGGER.debug("Could not convert {} to a URI; using the raw URL path", url);
            } catch (IllegalArgumentException e) {
                // File(URI) rejects e.g. URIs with an authority component (UNC paths).
                LOGGER.debug("Could not convert {} to a file path; using the raw URL path", url);
            }
        }
        return url.getFile();
    }

    /**
     * Checks whether the request's context-relative path is on the whitelist of paths that skip
     * XSS filtering. The match is exact.<br>
     *
     * @author duhui
     * @version 2.0.0, 2015-01-12
     * @param request the request to check
     * @return <li>true -- no filtering needed<li>false -- the request must be filtered
     */
    private boolean unfilter(HttpServletRequest request) {
        if (CollectionUtils.isEmpty(unFilterUrls)) {
            return false;
        }
        // Strip the context path as a literal prefix. String.replaceFirst would interpret it as a
        // regular expression, so '.', '+', '(' etc. in the context path would mis-match or throw.
        String path = StringUtils.removeStart(request.getRequestURI(), request.getContextPath());
        return unFilterUrls.contains(path);
    }
}
