LIFULL Creators Blog

LIFULL Creators Blogとは、株式会社LIFULLの社員が記事を共有するブログです。自分の役立つ経験や知識を広めることで世界をもっとFULLにしていきます。

solr で独自基準ソート(search component plugin 後編)

古川です。

search component plugin 後編です。 search component plugin 前編で作成した MyQueryComponent.java にスコア計算をするための処理を追加していきます。

Collector の作成

以前の記事で紹介しましたが、luceneのcollectorクラスを使うと、ソートのためのスコア計算を柔軟に定義することができます。

そこで、フィールドx、フィールドy の値を使って、

score = a*x*x + b*x*y + c*y*y + d*x + e*y + f

というスコア値を計算する MyCollectorクラスを作成します。

MyCollector.java

package jp.co.homes.searchcomponent;

import java.io.IOException;
import java.util.Collections;
import java.util.ArrayList;
import java.util.Comparator;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.FieldCache;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.AtomicReaderContext;

public class MyCollector extends Collector {
    private static final String FIELD_X = "x";
    private static final String FIELD_Y = "y";
    
    private int docBase;
    private int totalHits;
    private FieldCache.Floats x;
    private FieldCache.Floats y;
    private HitQueue pq;
    private ScoreDoc pqTop;
    private float maxScore;
    private float[] coefficients;

    public MyCollector(float[] _coefficients, int maxLen) {
        this.pq = new HitQueue(maxLen, true);
        this.pqTop = this.pq.top();
        this.maxScore = Float.NEGATIVE_INFINITY;
        this.coefficients = _coefficients;
    }

    @Override
    public void setNextReader(AtomicReaderContext context) throws IOException {
        AtomicReader reader = context.reader();
        this.docBase = context.docBase;
        this.x = FieldCache.DEFAULT.getFloats(reader, FIELD_X, false);
        this.y = FieldCache.DEFAULT.getFloats(reader, FIELD_Y, false);
    }

    @Override
    public void setScorer(Scorer socorer){}

    @Override
    public boolean acceptsDocsOutOfOrder(){
        return true;
    }

    @Override
    public void collect(int doc) {
        totalHits++;
        float x = (float)this.x.get(doc);
        float y = (float)this.y.get(doc);
        float score = this.coefficients[0]*x*x;
        score += this.coefficients[1]*x*y;
        score += this.coefficients[2]*y*y;
        score += this.coefficients[3]*x;
        score += this.coefficients[4]*y;
        score += this.coefficients[5];

        if (score < pqTop.score) {
            // 現在の最低点より小さいスコアの
            // ドキュメントは無視
            return;
        }
        if (this.maxScore < score) {
            this.maxScore = score;
        }
        
        int docid = this.docBase + doc;
        // 今先頭にあるスコアを、新しい値に交換
        this.pqTop.doc = docid;
        this.pqTop.score = score;
        
        // 新しい値のスコアを登録しHitQueueを再構築
        // 一番小さい値を持つオブジェクトがキューの先頭になる
        pqTop = pq.updateTop();
    }

    public int getTotalHits() {
        return this.totalHits;
    }

    public float getMaxScore() {
        return this.maxScore;
    }
    
    public ScoreDoc[] getResults() {
        int queueSize = this.pq.size();
        ScoreDoc[] results = new ScoreDoc[queueSize];
        for (int i=0; i<queueSize; i++) {
            results[i] = this.pq.pop();
        }
        return results;
    }
    
    // プライオリティーキュー
    class HitQueue extends PriorityQueue<ScoreDoc> {
        HitQueue(int size, boolean prePopulate) {
            super(size, prePopulate);
        }

        @Override
        public ScoreDoc getSentinelObject() {
            return new ScoreDoc(Integer.MAX_VALUE, Float.NEGATIVE_INFINITY);
        }

        @Override
        public final boolean lessThan(ScoreDoc hitA, ScoreDoc hitB) {
            return hitA.score < hitB.score;
       }
    }
}

Search Component Plugin 実装

前回作成した、MyQueryComponent.java を、このMyCollectorクラスを使ってスコア計算するよう変更します。

MyQueryComponent.java

package jp.co.homes.searchcomponent;

import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ConstantScoreQuery;

import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.*;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.ResultContext;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.QueryParsing;
import org.apache.solr.search.ReturnFields;
import org.apache.solr.search.DocSet;
import org.apache.solr.search.DocList;
import org.apache.solr.search.DocSlice;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.search.SolrReturnFields;
import org.apache.solr.search.SyntaxError;

import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.handler.component.ResponseBuilder;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

public class MyQueryComponent extends SearchComponent {
  public static final String COMPONENT_NAME = "my_query";

  @Override
  public void prepare(ResponseBuilder rb) throws IOException {
    SolrQueryRequest req = rb.req;
    SolrParams params = req.getParams();
    if (!params.getBool(COMPONENT_NAME, true)) {
      return;
    }
    SolrQueryResponse rsp = rb.rsp;

    ReturnFields returnFields = new SolrReturnFields( req );
    rsp.setReturnFields( returnFields );
    int flags = 0;
    if (returnFields.wantsScore()) {
      flags |= SolrIndexSearcher.GET_SCORES;
    }
    rb.setFieldFlags( flags );

    String defType = params.get(QueryParsing.DEFTYPE, QParserPlugin.DEFAULT_QTYPE);

    String queryString = rb.getQueryString();
    if (queryString == null) {
      queryString = params.get( CommonParams.Q );
      rb.setQueryString(queryString);
    }

    try {
      QParser parser = QParser.getParser(rb.getQueryString(), defType, req);
      Query q = parser.getQuery();
      if (q == null) {
        q = new BooleanQuery();
      }
      rb.setQuery( q );
      rb.setSortSpec( parser.getSort(true) );
      rb.setQparser(parser);
      rb.setScoreDoc(parser.getPaging());

      String[] fqs = req.getParams().getParams(CommonParams.FQ);
      if (fqs!=null && fqs.length!=0) {
        List<Query> filters = rb.getFilters();
        filters = filters == null ? new ArrayList<Query>(fqs.length) : new ArrayList<Query>(filters);
        for (String fq : fqs) {
          if (fq != null && fq.trim().length()!=0) {
            QParser fqp = QParser.getParser(fq, null, req);
            filters.add(fqp.getQuery());
          }
        }

        if (!filters.isEmpty()) {
          rb.setFilters( filters );
        }
      }
    } catch (SyntaxError e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
    }
  }

  @Override
  public void process(ResponseBuilder rb) throws IOException {
    SolrQueryRequest req = rb.req;
    SolrQueryResponse rsp = rb.rsp;
    SolrParams params = req.getParams();
    if (!params.getBool(COMPONENT_NAME, true)) {
      return;
    }
    SolrIndexSearcher searcher = req.getSearcher();
    SolrIndexSearcher.QueryCommand cmd = rb.getQueryCommand();

    // スコア計算なしで、ヒットするドキュメントID一覧を取得
    DocSet ds = null;
    if (cmd.getFilterList() == null) {
        ds = searcher.getDocSet(cmd.getQuery());
    } else {
        List<Query> newList = new ArrayList<Query>(cmd.getFilterList().size() +1);
        newList.add(cmd.getQuery());
        newList.addAll(cmd.getFilterList());
        ds = searcher.getDocSet(newList);
    }

    // ヒット件数を取得
    DocList dl = null;
    int numHits = ds.size();
    int requestOffset = cmd.getOffset();
    int requestLen = cmd.getLen();

    if (numHits < requestOffset) {
        dl = new DocSlice(0, 0, new int[0], null, numHits, 0.0f);
    } else {
        int queueSize = requestOffset + requestLen;
        if (numHits < queueSize) {
            queueSize = numHits;
        }
        // result セットされている値を  docset をluceneFilter
        // 独自スコア計算用のcollector として、
        // 独自collectorを使ってスコア計算
        Query luceneQuery = new ConstantScoreQuery(ds.getTopFilter());
        float[] coefficients = parseMyParmas(params);

        // 必要なサイズだけ
        MyCollector collector = new MyCollector(coefficients, queueSize);
        searcher.search(luceneQuery, null, collector);
        
        // 結果を取得
        ScoreDoc[] docs = collector.getResults();
        float maxScore = collector.getMaxScore();

        // レスポンス用に加工
        int[] ids = new int[queueSize];
        float[] scores = new float[queueSize];
        
        for (int i=0; i<queueSize; i++) {
            ScoreDoc sc = docs[i];
            // 小さい順に取り出してるので大きい順に並べ替え
            int index = queueSize - 1 - i;
            ids[index] = sc.doc;
            scores[index] = sc.score;
        }
        dl = new DocSlice(0, queueSize, ids, scores, numHits, maxScore);
        dl = dl.subset(requestOffset, queueSize - requestOffset);
    }
    ResultContext ctx = new ResultContext();
    ctx.docs = dl;
    ctx.query = rb.getQuery();
    rsp.add("response", ctx);
    rsp.getToLog().add("hits", numHits);
  }

  private float[] parseMyParmas(SolrParams params) {
      float[] coefficients = new float[6];
      String buf = params.get("myparams");
      String[] myparams = buf.split(",");
      for(int i=0; i<myparams.length; i++) {
          coefficients[i] = Float.parseFloat(myparams[i]);
      }
      return coefficients;
  }

  @Override
  public String getDescription() {
    return "my query component";
  }

  @Override
  public String getSource() {
    return "$URL: dummy $";
  }
}

コンパイル・設定

前回同様なので省略します。

動作確認

solr を起動して、前回作成したmyfunc を使ったクエリと、今回作成したmy_select の検索結果が同じであることを確認します。

myfunc(function query plugin)

http://localhost:8983/solr/collection1/select?echoParams=none&q=*:*&fq=x:[0 TO 500]&fl=id,x,y,myfunc(x,y,1,2,3,4,5,6)&sort=myfunc(x,y,1,2,3,4,5,6) desc&rows=3
<result name="response" numFound="500773" start="0">
  <doc>
    <str name="id">id789132</str>
    <float name="x">500.0</float>
    <float name="y">999.0</float>
    <float name="myfunc(x,y,1,2,3,4,5,6)">4250004.0</float>
  </doc>
  <doc>
    <str name="id">id96466</str>
    <float name="x">499.0</float>
    <float name="y">999.0</float>
    <float name="myfunc(x,y,1,2,3,4,5,6)">4247003.0</float>
  </doc>
  <doc>
    <str name="id">id882478</str>
    <float name="x">499.0</float>
    <float name="y">999.0</float>
    <float name="myfunc(x,y,1,2,3,4,5,6)">4247003.0</float>
  </doc>
</result>

my_select(search component plugin)

http://localhost:8983/solr/collection1/my_select?echoParams=none&q=*:*&fq=x:[0 TO 500]&fl=id,x,y,score&myparams=1,2,3,4,5,6&rows=3
<result name="response" numFound="500773" start="0" maxScore="4250004.0">
  <doc>
    <str name="id">id789132</str>
    <float name="x">500.0</float>
    <float name="y">999.0</float>
    <float name="score">4250004.0</float>
  </doc>
  <doc>
    <str name="id">id96466</str>
    <float name="x">499.0</float>
    <float name="y">999.0</float>
    <float name="score">4247003.0</float>
  </doc>
  <doc>
    <str name="id">id882478</str>
    <float name="x">499.0</float>
    <float name="y">999.0</float>
    <float name="score">4247003.0</float>
  </doc>
</result>

scoreが同じ場合のソート順が微妙に違うため、ときどき順序が前後しますが、score値は同じで、どちらも大きい順に並んでいることが分かります。

速度評価

スコア計算に関連する引数を変更しながら、5回クエリを実行してQTimeの平均値を計算してみました。

  • myfunc(function query plugin) 15.4ms
  • my_select(search component plugin) 11.6ms

search component plugin でも高速なスコア計算が実現できることが確認できました。

まとめ

search component plugin を使った独自ソートの実現方法を紹介しました。実際に使う場合には、きちんとしたエラー処理や SolrIndexSearcher.java の中で行われているようなクエリキャッシュの仕組みを追加していく必要がありますが、大体こんな感じでいけそうです。

function query plugin で実現できないソート用スコア計算が必要になった時は、search component plugin を使った方法を検討してみても良いかもしれません。