qdrant service
This commit is contained in:
parent
61749d94ed
commit
1828ae5693
25
src/app.py
25
src/app.py
@ -1,19 +1,34 @@
|
||||
from flask import Flask, jsonify
|
||||
from flask import Flask, jsonify, request
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from services.embedding_service import Embedding
|
||||
|
||||
|
||||
load_dotenv()
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/")
|
||||
def hello():
|
||||
return jsonify({"message":"Hello, World!"})
|
||||
# @app.route("/")
|
||||
# def hello():
|
||||
# return jsonify({"message":"Hello, World!"})
|
||||
|
||||
@app.get("/")
|
||||
def hello():
|
||||
return jsonify({"message":"Hello, World!"})
|
||||
|
||||
|
||||
@app.post("/get_embedding")
|
||||
def get_embedding():
|
||||
# Get the query text from request JSON
|
||||
data = request.get_json()
|
||||
query = data.get("query", "")
|
||||
|
||||
if not query:
|
||||
return jsonify({"error": "Query text is required"}), 400
|
||||
|
||||
# Call the embedding function
|
||||
vector = Embedding.call(query, is_query=True)
|
||||
|
||||
return jsonify({"query": query, "number of embedding": len(vector), "one embedding":len(vector[0])})
|
||||
|
||||
print(__name__)
|
||||
|
||||
if __name__=="__main__":
|
||||
|
||||
5
src/qdrant_services/policy.py
Normal file
5
src/qdrant_services/policy.py
Normal file
@ -0,0 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
class QdrantClientSelectionPolicy(Enum):
|
||||
RANDOM = 1,
|
||||
ROUND_ROBIN = 2,
|
||||
126
src/qdrant_services/qdrant.py
Normal file
126
src/qdrant_services/qdrant.py
Normal file
@ -0,0 +1,126 @@
|
||||
import inspect
|
||||
import logging
|
||||
import random
|
||||
from functools import wraps
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
from qdrant_services.settings import settings
|
||||
from services.service import Singleton
|
||||
|
||||
from qdrant_services.policy import QdrantClientSelectionPolicy
|
||||
|
||||
|
||||
class Qdrant(metaclass=Singleton):
|
||||
logger = logging.getLogger(__name__)
|
||||
QDRANT_DEFAULT_PORT = 6333
|
||||
|
||||
def __init__(self,
|
||||
host=None,
|
||||
port=None,
|
||||
collection_name=None,
|
||||
cluster=None,
|
||||
cluster_selection_policy=QdrantClientSelectionPolicy.ROUND_ROBIN,
|
||||
**kwargs
|
||||
):
|
||||
self.collection_name = collection_name or self.__class__.get_default_collection_name()
|
||||
|
||||
if settings.QDRANT_HTTPS is not None and kwargs.get('https') is None:
|
||||
kwargs['https'] = settings.QDRANT_HTTPS
|
||||
if settings.QDRANT_SSL_VERIFY is not None and kwargs.get('verify') is None:
|
||||
kwargs['verify'] = settings.QDRANT_SSL_VERIFY
|
||||
if settings.QDRANT_API_KEY is not None and kwargs.get('api_key') is None:
|
||||
kwargs['api_key'] = settings.QDRANT_API_KEY
|
||||
|
||||
host = host or settings.QDRANT_HOST
|
||||
port = port or settings.QDRANT_PORT
|
||||
cluster = cluster or settings.QDRANT_CLUSTER
|
||||
if cluster:
|
||||
self.logger.info(
|
||||
"Connecting to qdrant cluster: cluster: %s, collection_name: %s" % (cluster, self.collection_name))
|
||||
hosts = cluster.split(",")
|
||||
self._clients = []
|
||||
for host_port in hosts:
|
||||
host, port = host_port.split(":")
|
||||
port = port or self.QDRANT_DEFAULT_PORT
|
||||
port = int(port)
|
||||
self._clients.append(self.QdrantClientProxy(self, host=host, port=port, **kwargs))
|
||||
if len(self._clients) == 0:
|
||||
raise ValueError("No qdrant hosts provided")
|
||||
if len(self._clients) == 1:
|
||||
self.logger.warning("Only one host provided for cluster, using it as a single client")
|
||||
self._client = self._clients[0]
|
||||
self._clients = None
|
||||
self.cluster_selection_policy = cluster_selection_policy
|
||||
self._round_robin_index = -1
|
||||
else:
|
||||
self.logger.info(
|
||||
"Connecting to qdrant: host: %s, port: %s, collection_name: %s" % (host, port, self.collection_name))
|
||||
self._client = self.QdrantClientProxy(self, host=host, port=port, **kwargs)
|
||||
|
||||
def _get_client(self) -> QdrantClient:
|
||||
if hasattr(self, "_client"):
|
||||
return self._client
|
||||
clients = self._clients
|
||||
if self.cluster_selection_policy == QdrantClientSelectionPolicy.RANDOM:
|
||||
return random.choice(clients)
|
||||
else:
|
||||
self._round_robin_index = (self._round_robin_index + 1) % len(clients)
|
||||
return clients[self._round_robin_index]
|
||||
|
||||
@classmethod
|
||||
def get_client(cls, host=None, port=None, collection_name=None, cluster=None) -> QdrantClient:
|
||||
proxy = cls(host, port, collection_name, cluster)
|
||||
return proxy._get_client()
|
||||
|
||||
@classmethod
|
||||
def get_all_clients(cls, cluster=None, collection_name=None) -> [QdrantClient]:
|
||||
proxy = cls(collection_name=collection_name, cluster=cluster)
|
||||
if hasattr(proxy, "_client"):
|
||||
return [proxy._client]
|
||||
return proxy._clients
|
||||
|
||||
@classmethod
|
||||
def get_default_collection_name(cls) -> str:
|
||||
return settings.QDRANT_COLLECTION_NAME
|
||||
|
||||
class QdrantClientProxy(QdrantClient):
|
||||
|
||||
def __init__(self, outer_instance, *args, **kwargs):
|
||||
self.outer_instance = outer_instance
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
_excluded_methods = ['_dynamic_call_decorator']
|
||||
|
||||
def __getattribute__(self, item):
|
||||
attr = super().__getattribute__(item)
|
||||
if callable(attr) and item not in self._excluded_methods:
|
||||
return self._dynamic_call_decorator(attr)
|
||||
return attr
|
||||
|
||||
def _dynamic_call_decorator(self, func):
|
||||
METHODS_MODIFIED = []
|
||||
METHODS_EXCLUDED = []
|
||||
|
||||
def check_need_collection_name():
|
||||
sig = inspect.signature(func)
|
||||
params = sig.parameters
|
||||
kwargs = [param.name for param in params.values() if param.default == inspect.Parameter.empty]
|
||||
return "collection_name" in kwargs
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
while True:
|
||||
if "collection_name" in kwargs or \
|
||||
func.__name__ in METHODS_EXCLUDED:
|
||||
break
|
||||
if check_need_collection_name():
|
||||
METHODS_MODIFIED.append(func.__name__)
|
||||
else:
|
||||
METHODS_EXCLUDED.append(func.__name__)
|
||||
break
|
||||
kwargs["collection_name"] = self.outer_instance.collection_name
|
||||
break
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
2
src/qdrant_services/settings.py
Normal file
2
src/qdrant_services/settings.py
Normal file
@ -0,0 +1,2 @@
|
||||
QDRANT_HOST = "192.168.99.122"
|
||||
QDRANT_PORT = 6333
|
||||
@ -7,7 +7,7 @@ import os
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
# from app.settings import MODEL_FOLDER, TOKENIZER_FOLDER
|
||||
from core.service import Component
|
||||
from services.service import Component
|
||||
TOKENIZER_FOLDER = os.getenv("TOKENIZER_FOLDER")
|
||||
MODEL_FOLDER = os.getenv("MODEL_FOLDER")
|
||||
class Embedding(Component):
|
||||
@ -23,10 +23,13 @@ class Embedding(Component):
|
||||
# model = AutoModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v2')
|
||||
|
||||
# Load model from HuggingFace Hub
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_FOLDER)
|
||||
self.model = AutoModel.from_pretrained(MODEL_FOLDER)
|
||||
|
||||
|
||||
def process(self, input_clause, is_query=False):
|
||||
|
||||
return self.embedding_a_clause(input_clause, is_query)
|
||||
|
||||
def embedding_a_clause(self, input_clause, is_query=False):
|
||||
@ -52,6 +55,7 @@ class Embedding(Component):
|
||||
short_strs = [str]
|
||||
"""
|
||||
str_list = str.split('\n')
|
||||
|
||||
short_strs = []
|
||||
for str_item in str_list:
|
||||
str_item = re.sub(r'\([a-z]\)\.?', '', str_item)
|
||||
@ -60,6 +64,7 @@ class Embedding(Component):
|
||||
short_strs.extend(nltk.sent_tokenize(str_item))
|
||||
else:
|
||||
short_strs.append(str_item)
|
||||
|
||||
clause_segments = []
|
||||
temp_str = ""
|
||||
for short_str in short_strs:
|
||||
@ -75,7 +80,11 @@ class Embedding(Component):
|
||||
clause_segments[-1] = clause_segments[-1] + ' ' + temp_str
|
||||
else:
|
||||
clause_segments.append(temp_str)
|
||||
|
||||
# print("in embedding_service line 82, clause_segments: ", clause_segments)
|
||||
# print("clause_segments len: ", len(clause_segments))
|
||||
# print(len(clause_segments[0]))
|
||||
# print(len(clause_segments[1]))
|
||||
# print(len(clause_segments[2]))
|
||||
clause_vectors = self.embedding_sentences(clause_segments)
|
||||
for i in range(len(clause_segments)):
|
||||
if len(clause_segments[i]) < segment_threshold:
|
||||
@ -83,6 +92,7 @@ class Embedding(Component):
|
||||
###
|
||||
clause_vectors[i] /= np.linalg.norm(clause_vectors[i])
|
||||
###
|
||||
|
||||
return clause_vectors
|
||||
|
||||
def embedding_sentences(self, input_sentences):
|
||||
@ -95,11 +105,11 @@ class Embedding(Component):
|
||||
"""
|
||||
# Tokenize sentences
|
||||
encoded_input = self.tokenizer(input_sentences, padding=True, truncation=True, return_tensors='pt')
|
||||
|
||||
|
||||
# Compute token embeddings
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
|
||||
|
||||
# Perform pooling. In this case, max pooling.
|
||||
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
||||
|
||||
|
||||
0
src/services/recommend_service.py
Normal file
0
src/services/recommend_service.py
Normal file
96
src/services/service.py
Normal file
96
src/services/service.py
Normal file
@ -0,0 +1,96 @@
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# container contains components, components contains class
|
||||
class Container:
|
||||
def __init__(self):
|
||||
self.instances = {}
|
||||
|
||||
def get_component(self, name):
|
||||
return self.instances[name]
|
||||
|
||||
def has_component(self, name):
|
||||
return name in self.instances
|
||||
|
||||
def set_component(self, name, component):
|
||||
self.instances[name] = component
|
||||
|
||||
|
||||
container = Container()
|
||||
|
||||
|
||||
class Component(object):
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_class(cls):
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def get_name(cls):
|
||||
return "%s.%s" % (cls.__module__, cls.__name__)
|
||||
|
||||
@classmethod
|
||||
def _get_instance_from_container(cls):
|
||||
# container = get_container_instance()
|
||||
|
||||
name = cls.get_name()
|
||||
|
||||
if not container.has_component(name):
|
||||
|
||||
instance = cls()
|
||||
container.set_component(name, instance)
|
||||
instance.load()
|
||||
|
||||
|
||||
return container.get_component(name)
|
||||
|
||||
@classmethod
|
||||
def prepare(cls):
|
||||
"""
|
||||
Prepare this component.
|
||||
It calls `load` if this hasn't been done.
|
||||
"""
|
||||
return cls._get_instance_from_container()
|
||||
|
||||
@classmethod
|
||||
def call(cls, *args, **kwargs):
|
||||
|
||||
"""
|
||||
Call `process` and return the results.
|
||||
It calls `load` if this hasn't been done before calling `process`
|
||||
"""
|
||||
return cls._get_instance_from_container().process(*args, **kwargs)
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
Called when instance initialized, usually to load something can be shared between different calls.
|
||||
It's recommend for each module/app itself to load if it takes very long time
|
||||
It's guaranteed that the function MUST be called before any call to process.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def process(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
called when application shutdown
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
Loading…
Reference in New Issue
Block a user