File size: 1,606 Bytes
5fa4331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa999a3
5fa4331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
A script for a text sentiment analysis tool for the 🤗 Transformers Agent library.
"""

from transformers import Tool
from transformers.tools.base import get_default_device
from transformers import pipeline
from transformers import DistilBertTokenizerFast
from trainDistilBERT import DistilBertForMulticlassSequenceClassification
import torch



class SentAnalClassifierTool(Tool):
	"""
	A tool for sentiment analysis.

	Wraps the checkpoint named by ``ckpt`` in a "sentiment-analysis" pipeline and
	returns the label with the highest score for a given text. The tokenizer,
	model and pipeline are created lazily by ``setup`` on the first call.
	"""
	ckpt = "ongknsro/ACARISBERT-DistilBERT"
	name = "text_sentiment_analyzer"
	description = (
		"This is a tool that returns a sentiment label for a given text sequence. "
		"It takes raw text as input, and "
		"returns a sentiment label as output."
	)

	inputs = ["text"]
	outputs = ["text"]

	def __init__(self, device=None, **hub_kwargs) -> None:
		"""
		Store the configuration; nothing is loaded until ``setup`` runs.

		Args:
			device: Device the model and pipeline should run on. When ``None``,
				``get_default_device()`` is used at setup time.
			**hub_kwargs: Extra keyword arguments forwarded to the
				``from_pretrained`` calls in ``setup``.
		"""
		super().__init__()

		self.device = device
		self.pipeline = None
		self.hub_kwargs = hub_kwargs

	def setup(self):
		"""
		Load the tokenizer and model from ``ckpt`` and build the pipeline.

		Sets ``self.tokenizer``, ``self.model``, ``self.pipeline`` and marks the
		tool as initialized.
		"""
		if self.device is None:
			self.device = get_default_device()

		# Forward the hub options given to the constructor; with no options this
		# is the same call as before.
		# NOTE(review): assumes the custom model class accepts the standard
		# from_pretrained keyword arguments - confirm against trainDistilBERT.
		self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.ckpt, **self.hub_kwargs)

		self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(
			self.ckpt, **self.hub_kwargs
		).to(self.device)

		# Run the pipeline on the same device as the model. A hard-coded
		# ``device=0`` ignores the requested device and fails without a GPU.
		self.pipeline = pipeline(
			"sentiment-analysis",
			model=self.model,
			tokenizer=self.tokenizer,
			top_k=None,
			device=self.device,
		)

		self.is_initialized = True

	def __call__(self, task: str) -> str:
		"""
		Classify ``task`` and return the label with the highest score.

		Args:
			task: The raw text to analyze.

		Returns:
			The ``"label"`` value of the highest-scoring pipeline result.
		"""
		if not self.is_initialized:
			self.setup()

		outputs = self.pipeline(task)
		# outputs[0] is read as a sequence of {"label", "score"} dicts. Picking
		# the maximum score directly selects the same entry as a softmax
		# followed by argmax, because softmax preserves ordering.
		best = max(outputs[0], key=lambda item: item["score"])

		return best["label"]