diff --git a/.travis.yml b/.travis.yml index 91ac4185a..b0d34300f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,8 +8,12 @@ python: install: - pip install coveralls + - pip install pymongo - pip install -r requirements.txt +services: + - mongodb + script: - nosetests --with-coverage --cover-package=chatterbot diff --git a/chatterbot/adapters/storage/__init__.py b/chatterbot/adapters/storage/__init__.py index bc04eb066..9a7feac70 100644 --- a/chatterbot/adapters/storage/__init__.py +++ b/chatterbot/adapters/storage/__init__.py @@ -1,3 +1,4 @@ from .database import DatabaseAdapter from .jsondatabase import JsonDatabaseAdapter +from .mongodb import MongoDatabaseAdapter diff --git a/chatterbot/adapters/storage/mongodb.py b/chatterbot/adapters/storage/mongodb.py index ce46f9494..7b10f38d6 100644 --- a/chatterbot/adapters/storage/mongodb.py +++ b/chatterbot/adapters/storage/mongodb.py @@ -1,58 +1,112 @@ from chatterbot.adapters.storage import DatabaseAdapter +from chatterbot.adapters.exceptions import EmptyDatabaseException +from chatterbot.conversation import Statement from pymongo import MongoClient -# Use the default host and port -client = MongoClient() - -# We can also specify the host and port explicitly -#client = MongoClient('localhost', 27017) - -# Specify the name of the database -db = client['test-database'] - -# The mongo collection of statement documents -statements = db['statements'] - class MongoDatabaseAdapter(DatabaseAdapter): def __init__(self, **kwargs): - pass + super(MongoDatabaseAdapter, self).__init__(**kwargs) - def find(self, statement): - #def find(self, key): - return statements.find_one(statement) + self.database_name = self.kwargs.get("database", "chatterbot-database") - def insert(self, key, values): - statement_id = self.statements.insert_one(statement).inserted_id + # Use the default host and port + self.client = MongoClient() - return statement_id + # Specify the name of the database + self.database = self.client[self.database_name] - def update(self, key, **kwargs): + # The mongo collection of statement documents + self.statements = self.database['statements'] - values = self.database.data(key=key) + def count(self): + return self.statements.count() + + def find(self, statement_text): + values = self.statements.find_one({'text': statement_text}) - # Create the statement if it doesn't exist in the database if not values: - self.database[key] = {} - values = {} + return None + + del(values['text']) + return Statement(statement_text, **values) + def filter(self, **kwargs): + """ + Returns a list of statements in the database + that match the parameters specified. + """ + filter_parameters = kwargs.copy() + contains_parameters = {} + + # Exclude special arguments from the kwargs for parameter in kwargs: - values[parameter] = kwargs.get(parameter) + if "__" in parameter: + del(filter_parameters[parameter]) + + kwarg_parts = parameter.split("__") + + if kwarg_parts[1] == "contains": + key = kwarg_parts[0] + value = kwargs[parameter] + contains_parameters[key] = value + + filter_parameters.update(contains_parameters) + + matches = self.statements.find(filter_parameters) + matches = list(matches) - self.database[key] = values + results = [] - return values + for match in matches: + statement_text = match['text'] + del(match['text']) + results.append(Statement(statement_text, **match)) - def keys(self): - # The value has to be cast as a list for Python 3 compatibility - return list(self.database[0].keys()) + return results + + def update(self, statement): + # Do not alter the database unless writing is enabled + if not self.read_only: + data = statement.serialize() + + # Remove the text key from the data + self.statements.update({'text': statement.text}, data, True) + + # Make sure that an entry for each response is saved + for response_statement in statement.in_response_to: + response = self.find(response_statement) + if not response: + response = Statement(response_statement) + self.update(response) + + return statement def get_random(self): """ Returns a random statement from the database """ - from random import choice + from random import randint + + count = self.count() + + random_integer = randint(0, count -1) + + if self.count() < 1: + raise EmptyDatabaseException() + + statement = self.statements.find().limit(1).skip(random_integer * count) + + values = list(statement)[0] + statement_text = values['text'] + + del(values['text']) + return Statement(statement_text, **values) + + def drop(self): + """ + Remove the database. + """ + self.client.drop_database(self.database_name) - statement = choice(self.keys()) - return {statement: self.find(statement)} diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 85b15a536..36cfa4142 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -85,12 +85,10 @@ def get_response(self, input_text): """ Return the bot's response based on the input. """ - text_of_all_statements = self.storage._keys() - input_statement = Statement(input_text) # If no responses exist, use the input text - if not text_of_all_statements: + if not self.storage.count(): response = Statement(input_text) self.storage.update(response) self.recent_statements.append(response) @@ -100,6 +98,11 @@ def get_response(self, input_text): return response + all_statements = self.storage.filter() + text_of_all_statements = [] + for statement in all_statements: + text_of_all_statements.append(statement.text) + # Select the closest match to the input statement closest_match = self.logic.get( input_text, text_of_all_statements diff --git a/examples/terminal_mongo_example.py b/examples/terminal_mongo_example.py new file mode 100644 index 000000000..46fdfc9bd --- /dev/null +++ b/examples/terminal_mongo_example.py @@ -0,0 +1,40 @@ +from chatterbot import ChatBot + + +# Create a new instance of a ChatBot +bot = ChatBot("Terminal", + storage_adapter="chatterbot.adapters.storage.MongoDatabaseAdapter", + logic_adapter="chatterbot.adapters.logic.ClosestMatchAdapter", + io_adapter="chatterbot.adapters.io.TerminalAdapter", + database="chatterbot-database") + +user_input = "Type something to begin..." + +print(user_input) + +''' +In this example we use a while loop combined with a try-except statement. +This allows us to have a conversation with the chat bot until we press +ctrl-c or ctrl-d on the keyboard. +''' + +while True: + try: + ''' + ChatterBot's get_input method uses io adapter to get new input for + the bot to respond to. In this example, the TerminalAdapter gets the + input from the user's terminal. Other io adapters might retrieve input + differently, such as from various web APIs. + ''' + user_input = bot.get_input() + + ''' + The get_response method also uses the io adapter to determine how + the bot's output should be returned. In the case of the TerminalAdapter, + the output is printed to the user's terminal. + ''' + bot_input = bot.get_response(user_input) + + except (KeyboardInterrupt, EOFError, SystemExit): + break + diff --git a/requirements.txt b/requirements.txt index 6b0942e23..9d7e979a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fuzzywuzzy==0.6.0 +jsondatabase>=0.0.6 +pymongo==3.0.3 requests==2.7.0 requests-oauthlib==0.5.0 -jsondatabase==0.0.6 diff --git a/tests/storage_adapter_tests/test_jsondb_adapter.py b/tests/storage_adapter_tests/test_jsondb_adapter.py index 400ff943b..67cd2f00e 100644 --- a/tests/storage_adapter_tests/test_jsondb_adapter.py +++ b/tests/storage_adapter_tests/test_jsondb_adapter.py @@ -241,6 +241,20 @@ def test_filter_multiple_parameters_no_results(self): self.assertEqual(len(results), 0) + def test_filter_no_parameters(self): + """ + If not parameters are provided to the filter, + then all statements should be returned. + """ + statement1 = Statement("Testing...") + statement2 = Statement("Testing one, two, three.") + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter() + + self.assertEqual(len(results), 2) + class ReadOnlyJsonDatabaseAdapterTestCase(BaseJsonDatabaseAdapterTestCase): diff --git a/tests/storage_adapter_tests/test_mongo_adapter.py b/tests/storage_adapter_tests/test_mongo_adapter.py new file mode 100644 index 000000000..11809e2f6 --- /dev/null +++ b/tests/storage_adapter_tests/test_mongo_adapter.py @@ -0,0 +1,280 @@ +from unittest import TestCase +from chatterbot.adapters.storage import MongoDatabaseAdapter +from chatterbot.conversation import Statement + + +class BaseMongoDatabaseAdapterTestCase(TestCase): + + def setUp(self): + """ + Instantiate the adapter. + """ + database_name = "test_db" + + self.adapter = MongoDatabaseAdapter(database=database_name) + + def tearDown(self): + """ + Remove the test database. + """ + self.adapter.drop() + +class JsonDatabaseAdapterTestCase(BaseMongoDatabaseAdapterTestCase): + + def test_count_returns_zero(self): + """ + The count method should return a value of 0 + when nothing has been saved to the database. + """ + self.assertEqual(self.adapter.count(), 0) + + def test_count_returns_value(self): + """ + The count method should return a value of 1 + when one item has been saved to the database. + """ + statement = Statement("Test statement") + self.adapter.update(statement) + self.assertEqual(self.adapter.count(), 1) + + def test_statement_not_found(self): + """ + Test that None is returned by the find method + when a matching statement is not found. + """ + self.assertEqual(self.adapter.find("Non-existant"), None) + + def test_statement_found(self): + """ + Test that a matching statement is returned + when it exists in the database. + """ + statement = Statement("New statement") + self.adapter.update(statement) + + found_statement = self.adapter.find("New statement") + self.assertNotEqual(found_statement, None) + self.assertEqual(found_statement.text, statement.text) + + def test_update_adds_new_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + statement_found = self.adapter.find("New statement") + self.assertNotEqual(statement_found, None) + self.assertEqual(statement_found.text, statement.text) + + def test_update_modifies_existing_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + # Check the initial values + found_statement = self.adapter.find(statement.text) + self.assertEqual(found_statement.occurrence, 1) + + # Update the statement value + statement.update_occurrence_count() + self.adapter.update(statement) + + # CHeck that the values have changed + found_statement = self.adapter.find(statement.text) + self.assertEqual(found_statement.occurrence, 2) + + def test_get_random_returns_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + random_statement = self.adapter.get_random() + self.assertEqual(random_statement.text, statement.text) + + def test_find_returns_nested_responces(self): + response_list = [ + "Yes", "No" + ] + statement = Statement( + "Do you like this?", + in_response_to=response_list + ) + self.adapter.update(statement) + + result = self.adapter.find(statement.text) + + self.assertIn("Yes", result.in_response_to) + self.assertIn("No", result.in_response_to) + + + def test_filter_no_results(self): + statement1 = Statement( + "Testing...", + occurrence=4 + ) + self.adapter.update(statement1) + + results = self.adapter.filter(occurrence=100) + self.assertEqual(len(results), 0) + + def test_filter_equal_result(self): + statement1 = Statement( + "Testing...", + occurrence=22 + ) + statement2 = Statement( + "Testing one, two, three.", + occurrence=1 + ) + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter(occurrence=22) + self.assertEqual(len(results), 1) + self.assertIn(statement1, results) + + def test_filter_equal_multiple_results(self): + statement1 = Statement( + "Testing...", + occurrence=6 + ) + statement2 = Statement( + "Testing one, two, three.", + occurrence=1 + ) + statement3 = Statement( + "Test statement.", + occurrence=6 + ) + self.adapter.update(statement1) + self.adapter.update(statement2) + self.adapter.update(statement3) + + results = self.adapter.filter(occurrence=6) + self.assertEqual(len(results), 2) + self.assertIn(statement1, results) + self.assertIn(statement3, results) + + def test_filter_contains_result(self): + statement1 = Statement( + "Testing...", + in_response_to=[ + "What are you doing?" + ] + ) + statement2 = Statement( + "Testing one, two, three.", + in_response_to=[ + "Testing..." + ] + ) + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter( + in_response_to__contains="What are you doing?" + ) + self.assertEqual(len(results), 1) + self.assertIn(statement1, results) + + def test_filter_contains_no_result(self): + statement1 = Statement( + "Testing...", + in_response_to=[ + "What are you doing?" + ] + ) + self.adapter.update(statement1) + + results = self.adapter.filter( + in_response_to__contains="How do you do?" + ) + self.assertEqual(len(results), 0) + + def test_filter_multiple_parameters(self): + statement1 = Statement( + "Testing...", + occurrence=6, + in_response_to=[ + "Why are you counting?" + ] + ) + statement2 = Statement( + "Testing one, two, three.", + occurrence=6, + in_response_to=[ + "Testing..." + ] + ) + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter( + occurrence=6, + in_response_to__contains="Why are you counting?" + ) + + self.assertEqual(len(results), 1) + self.assertIn(statement1, results) + + def test_filter_multiple_parameters_no_results(self): + statement1 = Statement( + "Testing...", + occurrence=6, + in_response_to=[ + "Why are you counting?" + ] + ) + statement2 = Statement( + "Testing one, two, three.", + occurrence=1, + in_response_to=[ + "Testing..." + ] + ) + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter( + occurrence=6, + in_response_to__contains="Testing..." + ) + + self.assertEqual(len(results), 0) + + def test_filter_no_parameters(self): + """ + If not parameters are provided to the filter, + then all statements should be returned. + """ + statement1 = Statement("Testing...") + statement2 = Statement("Testing one, two, three.") + self.adapter.update(statement1) + self.adapter.update(statement2) + + results = self.adapter.filter() + + self.assertEqual(len(results), 2) + + +class ReadOnlyMongoDatabaseAdapterTestCase(BaseMongoDatabaseAdapterTestCase): + + def test_update_does_not_add_new_statement(self): + self.adapter.read_only = True + + statement = Statement("New statement") + self.adapter.update(statement) + + statement_found = self.adapter.find("New statement") + self.assertEqual(statement_found, None) + + def test_update_does_not_modify_existing_statement(self): + statement = Statement("New statement") + self.adapter.update(statement) + + self.adapter.read_only = True + + statement.update_occurrence_count() + self.adapter.update(statement) + + statement_found = self.adapter.find("New statement") + self.assertEqual(statement_found.text, statement.text) + self.assertEqual(statement.occurrence, 2) + self.assertEqual(statement_found.occurrence, 1) +