AkashKhamkar committed on
Commit
31c60be
1 Parent(s): 1804b9e

Upload segmentation.py

Browse files
Files changed (1) hide show
  1. segmentation.py +107 -0
segmentation.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ import attr
3
+ import pandas as pd
4
+ import numpy as np
5
+ import spacy
6
+ from nltk.tokenize.texttiling import TextTilingTokenizer
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from sentence_transformers import SentenceTransformer
9
+
10
@lru_cache
def load_sentence_transformer(model_name='all-MiniLM-L6-v2'):
    """
    Build the SentenceTransformer for ``model_name``.

    The function is wrapped in ``lru_cache``, so asking for the same
    ``model_name`` again returns the already-constructed object.

    Model names noted by the original author:
        all_MiniLM_L6_v2 - offline
        all-MiniLM-L6-v2 - online
    """
    return SentenceTransformer(model_name)
18
+
19
@lru_cache
def load_spacy():
    """Load spaCy's ``en_core_web_sm`` pipeline; repeat calls reuse the cached object."""
    pipeline = spacy.load('en_core_web_sm')
    return pipeline
22
+
23
+
24
# Shared module-level instances, created once at import time.
# SemanticTextSegmentation reads these globals directly.
model = load_sentence_transformer()  # sentence-embedding model (default name)
nlp = load_spacy()  # spaCy pipeline used for sentence splitting
26
+
27
+
28
@attr.s
class SemanticTextSegmentation:
    """
    Segment a call transcript by topic using TextTiling and sentence similarity.

    The transcript is first split into tiles with NLTK's
    ``TextTilingTokenizer``. Adjacent tiles whose mean sentence embeddings
    (from the module-level sentence-transformer ``model``) have a cosine
    similarity at or above a threshold are then merged back together.

    Parameters
    ----------
    data: pd.DataFrame
        The transcript in dataframe format.
    utterance: str
        Name of the column in ``data`` that holds the utterance text.

    Raises
    ------
    KeyError
        If ``utterance`` is not a column of ``data``.
    """

    data = attr.ib()
    utterance = attr.ib(default='utterance')

    def __attrs_post_init__(self):
        """Fail early, with a readable message, when the utterance column is missing."""
        columns = self.data.columns.tolist()
        if self.utterance not in columns:
            # KeyError is what the later ``self.data[self.utterance]`` lookup
            # would raise, so callers see the same exception type, just sooner.
            raise KeyError(
                f"column {self.utterance!r} not found in transcript dataframe; "
                f"available columns: {columns}"
            )

    def get_segments(self, threshold=0.7):
        """
        Return the transcript segments computed with TextTiling and sentence-transformer.

        Parameters
        ----------
        threshold: float
            Sentence similarity threshold. Adjacent tiles whose similarity is
            at or above this value are merged into one coherent segment.

        Returns
        -------
        new_segments: list
            List of segment strings; empty when the transcript has no utterances.
        """
        segments = self._text_tilling()
        if not segments:
            # _index_mapping would still report a group ``[0]`` for zero
            # segments, which would raise IndexError below.
            return []
        merge_index = self._merge_segments(segments, threshold)
        return [' '.join(segments[i] for i in group) for group in merge_index]

    def _merge_segments(self, segments, threshold):
        """
        Group indices of consecutive segments that are similar enough to merge.

        Returns a list of index lists, e.g. ``[[0, 1], [2]]`` means segments 0
        and 1 form one merged segment and segment 2 stands alone.
        """
        if not segments:
            return []
        # Embed each segment once; comparing adjacent pairs through
        # _get_similarity would encode every interior segment twice.
        embeddings = [self._embed_segment(text) for text in segments]
        segment_map = [0]
        for left, right in zip(embeddings[:-1], embeddings[1:]):
            sim = cosine_similarity(left, right)[0][0]
            # 0 -> stay in the current group, 1 -> start a new group here.
            segment_map.append(0 if sim >= threshold else 1)
        return self._index_mapping(segment_map)

    def _index_mapping(self, segment_map):
        """
        Turn a 0/1 boundary map into groups of indices.

        A ``1`` at position ``p`` closes the running group and starts a new
        one at ``p``; a ``0`` appends ``p`` to the running group.
        """
        index_list = []
        temp = []
        for index, flag in enumerate(segment_map):
            if flag == 1:
                index_list.append(temp)
                temp = [index]
            else:
                temp.append(index)
        # Flush the group that was still open when the map ended.
        index_list.append(temp)
        return index_list

    def _embed_segment(self, text):
        """
        Return the mean sentence embedding of ``text`` as a ``(1, dim)`` array.

        Sentences are split with the module-level spaCy pipeline ``nlp``, and
        only those with more than one space-separated token are embedded.
        """
        sentences = [
            sent.text.strip()
            for sent in nlp(text).sents
            if len(sent.text.split(' ')) > 1
        ]
        if not sentences:
            # Every sentence was filtered out. Averaging an empty encoding
            # would give NaN or a mis-shaped vector, so embed the whole text.
            sentences = [text.strip()]
        embeddings = model.encode(sentences)
        return np.mean(embeddings, axis=0).reshape(1, -1)

    def _get_similarity(self, text1, text2):
        """
        Cosine similarity between the mean sentence embeddings of two texts.

        Both texts go through the same sentence filter, so the result does not
        depend on argument order. Returns the ``(1, 1)`` array produced by
        ``cosine_similarity``.
        """
        return cosine_similarity(
            self._embed_segment(text1), self._embed_segment(text2)
        )

    def _text_tilling(self):
        """
        Split the transcript into tiles with NLTK's TextTilingTokenizer.

        Utterances are joined with a blank-line separator so the tokenizer
        sees paragraph breaks; the separator is replaced by a single space in
        the returned tiles. Returns ``[]`` when there are no utterances.

        NOTE(review): TextTilingTokenizer appears to raise ValueError on very
        short inputs; that is left to propagate - confirm desired handling.
        """
        utterances = self.data[self.utterance].tolist()
        if not utterances:
            return []
        tt = TextTilingTokenizer(w=15, k=10)
        text = '\n\n\t'.join(utterances)
        return [tile.replace("\n\n\t", ' ') for tile in tt.tokenize(text)]