-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ChromaDB Support for Python SDK #110
base: main
Are you sure you want to change the base?
Conversation
Just wanted to let you know that I've been busy lately with my day job and probably won't be able to get to this for at least a week. |
Sounds good, and thank you for letting me know @ristomcgehee! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I finally found some time to review this PR. Here are my thoughts!
python-sdk/README.md
Outdated
@@ -45,14 +45,18 @@ pip install rebuff | |||
|
|||
### Detect prompt injection on user input | |||
|
|||
For vector database, Rebuff supports Pinecone (default) and Chroma. To use Chroma, install Rebuff with extras: `pip install rebuff[chromadb]` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For vector database, Rebuff supports Pinecone (default) and Chroma. To use Chroma, install Rebuff with extras: `pip install rebuff[chromadb]` | |
For vector database, Rebuff supports Pinecone (default) and Chroma. |
That same information is repeated a few lines later, so I don't think we need it here.
input: str, similarity_threshold: float, vector_store: Pinecone | ||
input: str, | ||
similarity_threshold: float, | ||
vector_store: Union[Pinecone, Optional[Chroma]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vector_store: Union[Pinecone, Optional[Chroma]], | |
vector_store: VectorStore, |
And then you'd need to import VectorStore
from langchain at the top of the file.
try: | ||
import chromadb | ||
|
||
chromadb_installed = True | ||
except ImportError: | ||
print( | ||
"To use Chromadb, please install rebuff with rebuff extras. 'pip install \"rebuff[chromadb]\"'" | ||
) | ||
chromadb_installed = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to suggest a different approach for handling the fact that chromadb
might not be installed. If I understand the code correctly, even if a user is using Pinecone, they'll always see the warning "To use Chromadb, please install...". Also, it's a bit unusual conditionally defining classes and methods.
I would move ChromaCosineSimilarity
to a different file but keep init_chroma
in this file. At the beginning of init_chroma
, you import chromadb
and ChromaCosineSimilarity
within a try-except. If the import fails, then you display "To use Chromadb, please install...". Then you'd no longer need the chromadb_installed
variable.
python-sdk/rebuff/sdk.py
Outdated
) | ||
|
||
elif self.vector_db.name == "CHROMA": | ||
from rebuff.detect_pi_vectorbase import init_chroma |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you take my suggestion for refactoring detect_pi_vectorbase.py
, you'll be able to move this import to the top of the file.
python-sdk/rebuff/sdk.py
Outdated
@@ -83,7 +118,7 @@ def detect_injection( | |||
rebuff_heuristic_score = 0 | |||
|
|||
if check_vector: | |||
self.initialize_pinecone() | |||
self.initialize_vector_store() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.initialize_vector_store() | |
if self.vector_store is None: | |
self.initialize_vector_store() |
python-sdk/tests/test_sdk.py
Outdated
def rebuff(request) -> RebuffSdk: | ||
rb = RebuffSdk( | ||
get_environment_variable("OPENAI_API_KEY"), | ||
request.param, | ||
get_environment_variable("PINECONE_API_KEY"), | ||
get_environment_variable("PINECONE_INDEX_NAME"), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For most test methods, it doesn't do any good testing on both pinecone and chroma; only test_detect_injection_vectorbase
needs to test with both. You could do this in the fixture:
def rebuff(request) -> RebuffSdk: | |
rb = RebuffSdk( | |
get_environment_variable("OPENAI_API_KEY"), | |
request.param, | |
get_environment_variable("PINECONE_API_KEY"), | |
get_environment_variable("PINECONE_INDEX_NAME"), | |
) | |
def rebuff(request) -> RebuffSdk: | |
vector_db = request.param if hasattr(request, "param") else VectorDB.PINECONE | |
rb = RebuffSdk( | |
get_environment_variable("OPENAI_API_KEY"), | |
vector_db, | |
get_environment_variable("PINECONE_API_KEY"), | |
get_environment_variable("PINECONE_INDEX_NAME"), | |
) |
Which would allow you to delete:
@pytest.mark.parametrize(
"rebuff",
[VectorDB.PINECONE, VectorDB.CHROMA],
ids=["pinecone", "chroma"],
indirect=True,
)
from most of the methods.
python-sdk/tests/test_sdk.py
Outdated
def test_detect_injection_vectorbase( | ||
rebuff: RebuffSdk, | ||
add_documents_to_chroma, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This parameter isn't getting used in this method, it probably shouldn't be a parameter. What I would probably do is add add_documents_to_chroma
to the end of the rebuff
function fixture.
) | ||
|
||
chroma_collection = ChromaCosineSimilarity( | ||
client=chromadb.Client(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When chromadb.Client()
is invoked, is that creating an in-memory database that only persists as long as the process does? If that's the case, it wouldn't really work for a production use case. I think what we'd want to do is use chromadb.HttpClient
to connect to a remote server. If we go that route, it looks like we might be able to use chromadb-client
instead which is a more lightweight version of chromadb
(https://docs.trychroma.com/usage-guide?lang=py#using-the-python-http-only-client).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the suggestion! I have updated the code to use chromadb.HttpClient
. I will check if we can also update the dependency to chromadb-client
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the dependency to chromadb-client
.
Thank you @ristomcgehee for the review, and suggestions! I have tried to incorporate most of them. Also using |
python-sdk/tests/test_sdk.py
Outdated
get_environment_variable("PINECONE_API_KEY"), | ||
get_environment_variable("PINECONE_INDEX_NAME"), | ||
) | ||
if hasattr(request, "param") and request.param == VectorDB.CHROMA: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if hasattr(request, "param") and request.param == VectorDB.CHROMA: | |
if vector_db == VectorDB.CHROMA: |
python-sdk/rebuff/sdk.py
Outdated
if self.vector_db.name == "PINECONE": | ||
self.pinecone_apikey = pinecone_apikey | ||
self.pinecone_index = pinecone_index | ||
|
||
elif self.vector_db.name == "CHROMA": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self.vector_db.name == "PINECONE": | |
self.pinecone_apikey = pinecone_apikey | |
self.pinecone_index = pinecone_index | |
elif self.vector_db.name == "CHROMA": | |
if self.vector_db == VectorDB.PINECONE: | |
self.pinecone_apikey = pinecone_apikey | |
self.pinecone_index = pinecone_index | |
elif self.vector_db == VectorDB.CHROMA: |
Similar recommendation for within initialize_vector_store()
.
python-sdk/rebuff/sdk.py
Outdated
@@ -83,7 +114,8 @@ def detect_injection( | |||
rebuff_heuristic_score = 0 | |||
|
|||
if check_vector: | |||
self.initialize_pinecone() | |||
if self.initialize_vector_store() is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self.initialize_vector_store() is None: | |
if self.vector_store is None: |
Thank you @ristomcgehee, I have now added Docker files for Chroma server. Though not sure why the JS and Python tests (integration tests) are failing. They are detecting prompt injection when there is none |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though not sure why the JS and Python tests (integration tests) are failing. They are detecting prompt injection when there is none
I believe this occurs sometimes because LLMs are non-deterministic. Sometimes, you'll give a benign input and it will give a score of 0.6
or 0.8
. A couple ways we could address that:
- Set the temperature to 0 when calling OpenAI
- Retry the tests multiple times when they fail
For now, you could also just re-run the tests and they'll likely pass.
@@ -32,17 +34,48 @@ if result.injection_detected: | |||
print("Possible injection detected. Take corrective action.") | |||
``` | |||
|
|||
#### Chroma vector database |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For a quickstart page, the simpler you can make it the better. I'd recommend taking out the pinecone section and just show how to use the SDK with Chroma DB (since it requires less setup than Pinecone).
|
||
application: | ||
env_file: | ||
- .env |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to mention that an .env
file is necessary in documentation, as well as describe what is necessary to be included. Something that projects often do is have an example.env
file in the repo that people can copy and fill in with their values.
Thank you for the suggestions. I have tried rerunning the tests multiple times, and have also set temperature to 0 when calling OpenAI, thought don't think it is helping much. Python SDK tests are also failing because of connection error with chroma server when they do pass locally. I will continue to debug this, but if you have any suggestion please do share. |
if vector_db == VectorDB.CHROMA: | ||
rb = RebuffSdk(get_environment_variable("OPENAI_API_KEY"), vector_db) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aren't these two lines unnecessary since we already initialized rb
on line 10?
documents = dataset["train"]["text"][:200] | ||
|
||
metadatas = [{"source": "Rebuff"}] * len(documents) | ||
documents_ids = [str(uuid.uuid1()) for i in range(1, len(documents) + 1)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor simplification:
documents_ids = [str(uuid.uuid1()) for i in range(1, len(documents) + 1)] | |
documents_ids = [str(uuid.uuid1()) for _ in range(len(documents))] |
chroma_collections = [ | ||
collection for collection in chroma_client.list_collections() | ||
] | ||
if chroma_collections: | ||
chroma_collections_names = [ | ||
collection.name for collection in chroma_collections | ||
] | ||
|
||
if rebuff_collection_name in chroma_collections_names: | ||
document_collection = chroma_client.get_collection( | ||
name=rebuff_collection_name, | ||
embedding_function=openai_embedding_function, | ||
) | ||
|
||
else: | ||
document_collection = chroma_client.create_collection( | ||
name=rebuff_collection_name, | ||
metadata={"hnsw:space": "cosine"}, | ||
embedding_function=openai_embedding_function, | ||
) | ||
|
||
document_collection.add( | ||
documents=documents, metadatas=metadatas, ids=documents_ids | ||
) | ||
|
||
else: | ||
document_collection = chroma_client.create_collection( | ||
name=rebuff_collection_name, | ||
metadata={"hnsw:space": "cosine"}, | ||
embedding_function=openai_embedding_function, | ||
) | ||
|
||
document_collection.add( | ||
documents=documents, metadatas=metadatas, ids=documents_ids | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a function get_or_create_collection, so you can simplify all this to:
chroma_collections = [ | |
collection for collection in chroma_client.list_collections() | |
] | |
if chroma_collections: | |
chroma_collections_names = [ | |
collection.name for collection in chroma_collections | |
] | |
if rebuff_collection_name in chroma_collections_names: | |
document_collection = chroma_client.get_collection( | |
name=rebuff_collection_name, | |
embedding_function=openai_embedding_function, | |
) | |
else: | |
document_collection = chroma_client.create_collection( | |
name=rebuff_collection_name, | |
metadata={"hnsw:space": "cosine"}, | |
embedding_function=openai_embedding_function, | |
) | |
document_collection.add( | |
documents=documents, metadatas=metadatas, ids=documents_ids | |
) | |
else: | |
document_collection = chroma_client.create_collection( | |
name=rebuff_collection_name, | |
metadata={"hnsw:space": "cosine"}, | |
embedding_function=openai_embedding_function, | |
) | |
document_collection.add( | |
documents=documents, metadatas=metadatas, ids=documents_ids | |
) | |
document_collection = chroma_client.get_or_create_collection( | |
name=rebuff_collection_name, | |
metadata={"hnsw:space": "cosine"}, | |
embedding_function=openai_embedding_function, | |
) | |
document_collection.add( | |
documents=documents, metadatas=metadatas, ids=documents_ids | |
) |
|
||
# Wait for the documents to be added to the collection | ||
count = 0 | ||
while document_collection.count() == 0 and count < 5: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to wait here? Does document_collection.add()
return before the documents get fully committed to the data store? If so, that would be a useful thing to mention in a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you give this file a name that matches its purpose better, e.g. "add_chroma_docs.py"?
|
||
collection_status = False | ||
count = 0 | ||
while not collection_status and count <= 5: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this while loop necessary? Does creating the collection or add the documents often fail?
) | ||
collection_status = True | ||
except Exception as e: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If chroma_client.get_or_create_collection
or document_collection.add
is consistently failing, it would make debugging harder since we're not showing any error messages. Maybe you could print e
along with an error description outside of the while loop if collection_status == False
.
This PR adds Chroma DB support for Python SDK.