macavaney commited on
Commit
6b485fc
1 Parent(s): b38a07c
Files changed (1) hide show
  1. app.py +11 -19
app.py CHANGED
@@ -7,13 +7,13 @@ import pyterrier as pt
7
  pt.init()
8
  import pyt_splade
9
  from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D
10
- factory_max = pyt_splade.SpladeFactory(agg='max')
11
- factory_sum = pyt_splade.SpladeFactory(agg='sum')
12
 
13
  COLAB_NAME = 'pyterrier_splade.ipynb'
14
  COLAB_INSTALL = '''
15
  !pip install -q git+https://github.com/naver/splade
16
- !pip install -q git+https://github.com/seanmacavaney/pyt_splade@misc
17
  '''.strip()
18
 
19
  def generate_vis(df, mode='Document'):
@@ -24,15 +24,9 @@ def generate_vis(df, mode='Document'):
24
  max_score = max(max(t.values()) for t in df['toks'])
25
  for row in df.itertuples(index=False):
26
  if mode == 'Query':
27
- tok_scores = {m.group(2): float(m.group(1)) for m in re.finditer(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)', row.query)}
28
- for key, value in list(tok_scores.items()):
29
- if key.startswith('#base64('):
30
- b64 = re.search('#base64\(([^)]+)\)', key).group(1)
31
- del tok_scores[key]
32
- key = base64.b64decode(b64).decode()
33
- tok_scores[key] = value
34
  max_score = max(tok_scores.values())
35
- orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
36
  id = row.qid
37
  else:
38
  tok_scores = row.toks
@@ -55,38 +49,36 @@ def generate_vis(df, mode='Document'):
55
 
56
  def predict_query(input, agg):
57
  code = f'''import pandas as pd
58
- import pyterrier as pt ; pt.init()
59
  import pyt_splade
60
 
61
- splade = pyt_splade.SpladeFactory(agg={repr(agg)})
62
 
63
- query_pipeline = splade.query()
64
 
65
  query_pipeline({df2code(input)})
66
  '''
67
  pipeline = {
68
  'max': factory_max,
69
  'sum': factory_sum
70
- }[agg].query()
71
  res = pipeline(input)
72
  vis = generate_vis(res, mode='Query')
73
  return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)
74
 
75
  def predict_doc(input, agg):
76
  code = f'''import pandas as pd
77
- import pyterrier as pt ; pt.init()
78
  import pyt_splade
79
 
80
- splade = pyt_splade.SpladeFactory(agg={repr(agg)})
81
 
82
- doc_pipeline = splade.indexing()
83
 
84
  doc_pipeline({df2code(input)})
85
  '''
86
  pipeline = {
87
  'max': factory_max,
88
  'sum': factory_sum
89
- }[agg].indexing()
90
  res = pipeline(input)
91
  vis = generate_vis(res, mode='Document')
92
  res['toks'] = [json.dumps({k: round(v, 4) for k, v in t.items()}) for t in res['toks']]
 
7
  pt.init()
8
  import pyt_splade
9
  from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D
10
+ factory_max = pyt_splade.Splade(agg='max')
11
+ factory_sum = pyt_splade.Splade(agg='sum')
12
 
13
  COLAB_NAME = 'pyterrier_splade.ipynb'
14
  COLAB_INSTALL = '''
15
  !pip install -q git+https://github.com/naver/splade
16
+ !pip install -q git+https://github.com/cmacdonald/pyt_splade
17
  '''.strip()
18
 
19
  def generate_vis(df, mode='Document'):
 
24
  max_score = max(max(t.values()) for t in df['toks'])
25
  for row in df.itertuples(index=False):
26
  if mode == 'Query':
27
+ tok_scores = row.query_toks
28
+ orig_tokens = factory_max.tokenizer.tokenize(row.text)
 
 
 
 
 
29
  max_score = max(tok_scores.values())
 
30
  id = row.qid
31
  else:
32
  tok_scores = row.toks
 
49
 
50
  def predict_query(input, agg):
51
  code = f'''import pandas as pd
 
52
  import pyt_splade
53
 
54
+ splade = pyt_splade.Splade(agg={agg!r})
55
 
56
+ query_pipeline = splade.query_encoder()
57
 
58
  query_pipeline({df2code(input)})
59
  '''
60
  pipeline = {
61
  'max': factory_max,
62
  'sum': factory_sum
63
+ }[agg].query_encoder()
64
  res = pipeline(input)
65
  vis = generate_vis(res, mode='Query')
66
  return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)
67
 
68
  def predict_doc(input, agg):
69
  code = f'''import pandas as pd
 
70
  import pyt_splade
71
 
72
+ splade = pyt_splade.Splade(agg={repr(agg)})
73
 
74
+ doc_pipeline = splade.doc_encoder()
75
 
76
  doc_pipeline({df2code(input)})
77
  '''
78
  pipeline = {
79
  'max': factory_max,
80
  'sum': factory_sum
81
+ }[agg].doc_encoder()
82
  res = pipeline(input)
83
  vis = generate_vis(res, mode='Document')
84
  res['toks'] = [json.dumps({k: round(v, 4) for k, v in t.items()}) for t in res['toks']]