Skip to content

Commit

Permalink
Merge pull request #138 from brainless/feature/continue-chat-with-his…
Browse files Browse the repository at this point in the history
…tory_135

Continue chat with history
  • Loading branch information
brainless committed Jun 18, 2024
2 parents 349aca1 + dd0a09c commit ded3a5c
Show file tree
Hide file tree
Showing 37 changed files with 499 additions and 369 deletions.
44 changes: 32 additions & 12 deletions src-tauri/src/ai_integration/commands.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use log::error;
use tauri::State;
use tokio::task::spawn_blocking;

use crate::error::DwataError;
use crate::{
error::DwataError,
workspace::{crud::CRUDRead, DwataDb},
};

use super::models::AIModel;
use super::{models::AIModel, AIIntegration, AIProvider};

#[tauri::command]
pub async fn get_ai_model_list(_usable_only: Option<bool>) -> Result<Vec<AIModel>, DwataError> {
Expand All @@ -12,19 +16,31 @@ pub async fn get_ai_model_list(_usable_only: Option<bool>) -> Result<Vec<AIModel

#[tauri::command]
pub async fn get_ai_model_choice_list(
_usable_only: Option<bool>,
usable_only: Option<bool>,
db: State<'_, DwataDb>,
) -> Result<Vec<(String, String)>, DwataError> {
let mut result: Vec<AIModel> = vec![];
result.extend(AIModel::get_models_for_openai());
result.extend(AIModel::get_models_for_groq());
result.extend(AIModel::get_models_for_anthropic());
match AIModel::get_models_for_ollama().await {
Ok(models) => result.extend(models),
Err(err) => {
error!("Could not get Ollama models\n Error: {}", err);
if let Some(false) = usable_only {
// We load all the AI models
result.extend(AIModel::get_all_models().await);
} else {
let mut db_guard = db.lock().await;
// We load all the AI providers that are usable
let ai_integrations: Vec<AIIntegration> = AIIntegration::read_all(&mut db_guard).await?;
for ai_integration in ai_integrations {
match ai_integration.ai_provider {
AIProvider::OpenAI => result.extend(AIModel::get_models_for_openai()),
AIProvider::Groq => result.extend(AIModel::get_models_for_groq()),
AIProvider::Ollama => match AIModel::get_models_for_ollama().await {
Ok(models) => result.extend(models),
Err(err) => {
error!("Could not get Ollama models\n Error: {}", err);
}
},
}
}
}
result.extend(AIModel::get_models_for_mistral());

Ok(result
.iter()
.map(|x| {
Expand All @@ -34,7 +50,11 @@ pub async fn get_ai_model_choice_list(
x.ai_provider.clone().to_string(),
x.api_name.clone()
),
x.label.clone(),
format!(
"{} - {}",
x.ai_provider.clone().to_string(),
x.label.clone()
),
)
})
.collect())
Expand Down
12 changes: 6 additions & 6 deletions src-tauri/src/ai_integration/crud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ impl CRUDCreateUpdate for AIIntegrationCreateUpdate {
}

fn get_column_names_values(&self) -> VecColumnNameValue {
let mut name_values: VecColumnNameValue = VecColumnNameValue::default();
let mut names_values: VecColumnNameValue = VecColumnNameValue::default();
if let Some(x) = &self.label {
name_values.push_name_value("label", InputValue::Text(x.clone()));
names_values.push_name_value("label", InputValue::Text(x.clone()));
}
if let Some(x) = &self.ai_provider {
name_values.push_name_value("ai_provider", InputValue::Text(x.clone()));
names_values.push_name_value("ai_provider", InputValue::Text(x.clone()));
}
if let Some(x) = &self.api_key {
name_values.push_name_value("api_key", InputValue::Text(x.clone()));
names_values.push_name_value("api_key", InputValue::Text(x.clone()));
}
name_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
name_values
names_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
names_values
}
}

Expand Down
7 changes: 3 additions & 4 deletions src-tauri/src/ai_integration/providers/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use crate::text_generation::TextGenerationRequest;
use serde::{Deserialize, Serialize};

use super::openai::ChatRequestMessage;

#[derive(Serialize)]
pub struct OllamaTextGenerationRequest {
pub model: String,
pub messages: Vec<ChatRequestMessage>,
pub messages: Vec<TextGenerationRequest>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
}
Expand All @@ -30,7 +29,7 @@ pub struct OllamaTextGenerationRequest {
pub struct OllamaTextGenerationResponse {
pub model: String,
pub created_at: String,
pub message: ChatRequestMessage,
pub message: TextGenerationRequest,
pub done: bool,
pub total_duration: u64,
pub load_duration: u64,
Expand Down
12 changes: 4 additions & 8 deletions src-tauri/src/ai_integration/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::text_generation::TextGenerationRequest;

#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct Usage {
pub(crate) prompt_tokens: i64,
Expand Down Expand Up @@ -54,16 +56,10 @@ pub(crate) struct LogProb {
pub(crate) content: Option<Vec<LogProbContent>>,
}

#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ChatRequestMessage {
pub(crate) role: String,
pub(crate) content: String,
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize)]
pub(crate) struct OpenAIChatRequest {
pub(crate) model: String,
pub(crate) messages: Vec<ChatRequestMessage>,
pub(crate) messages: Vec<TextGenerationRequest>,
pub(crate) tools: Vec<OpenAITool>,
}

Expand Down
2 changes: 1 addition & 1 deletion src-tauri/src/chat/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ impl Configurable for Chat {
fn get_schema() -> Configuration {
Configuration::new(
"Chat with AI",
"Chat with AI models (or other users), sharing your objectives and let AI help you find solutions",
"Chat with AI models, sharing your objectives and let AI help you find solutions",
vec![
FormField::new(
"message",
Expand Down
23 changes: 14 additions & 9 deletions src-tauri/src/chat/crud.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Chat, ChatCreateUpdate, ChatFilters};
use super::{Chat, ChatCreateUpdate, ChatFilters, Role};
use crate::{
content::content::{Content, ContentType},
error::DwataError,
Expand Down Expand Up @@ -31,25 +31,30 @@ impl CRUDCreateUpdate for ChatCreateUpdate {
}

fn get_column_names_values(&self) -> VecColumnNameValue {
let mut name_values: VecColumnNameValue = VecColumnNameValue::default();
let mut names_values: VecColumnNameValue = VecColumnNameValue::default();
if let Some(x) = &self.role {
name_values.push_name_value("role", InputValue::Text(x.clone()));
names_values.push_name_value("role", InputValue::Text(x.clone()));
}
if let Some(x) = &self.root_chat_id {
name_values.push_name_value("root_chat_id", InputValue::ID(*x));
names_values.push_name_value("root_chat_id", InputValue::ID(*x));
}
if let Some(x) = &self.message {
name_values.push_name_value("message", InputValue::Text(x.clone()));
names_values.push_name_value("message", InputValue::Text(x.clone()));
}
if let Some(x) = &self.requested_ai_model {
name_values.push_name_value("requested_ai_model", InputValue::Text(x.clone()));
names_values.push_name_value("requested_ai_model", InputValue::Text(x.clone()));
}
// if let Some(x)= &self.requested_content_format {
// name_values.push_name_value("requested_content_format", )
// }
name_values.push_name_value("is_system_chat", InputValue::Bool(false));
name_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
name_values
if let Some(x) = &self.role {
names_values.push_name_value("role", InputValue::Text(x.clone()));
} else {
names_values.push_name_value("role", InputValue::Text(Role::User.to_string()));
}
names_values.push_name_value("is_system_chat", InputValue::Bool(false));
names_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
names_values
}

async fn post_insert(
Expand Down
22 changes: 17 additions & 5 deletions src-tauri/src/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use std::str::FromStr;

use chrono::{DateTime, Utc};
use std::str::FromStr;
// use crate::chat::api_types::APIChatContextNode;
use crate::error::DwataError;
use chrono::serde::ts_milliseconds;
use serde::{Deserialize, Serialize};
use sqlx::{types::Json, FromRow, Type};
use ts_rs::TS;

use crate::error::DwataError;

pub mod configuration;
pub mod crud;

#[derive(Serialize, Type, TS)]
#[derive(Deserialize, Serialize, Type, TS)]
#[sqlx(rename_all = "lowercase")]
#[serde(rename_all = "lowercase")]
#[ts(export, export_to = "../src/api_types/")]
pub enum Role {
User,
Expand All @@ -33,6 +33,16 @@ impl FromStr for Role {
}
}

impl ToString for Role {
fn to_string(&self) -> String {
match self {
Role::User => "user".to_string(),
Role::Assistant => "assistant".to_string(),
Role::System => "system".to_string(),
}
}
}

#[derive(Deserialize, Serialize, TS)]
#[ts(export, export_to = "../src/api_types/")]
pub struct ChatToolResponse {
Expand All @@ -49,6 +59,7 @@ pub struct Chat {
pub id: i64,

// This is null for the first chat
#[ts(type = "number")]
pub root_chat_id: Option<i64>,

pub role: Option<Role>,
Expand Down Expand Up @@ -84,6 +95,7 @@ pub struct Chat {
#[ts(export, rename_all = "camelCase", export_to = "../src/api_types/")]
pub struct ChatCreateUpdate {
pub role: Option<String>,
#[ts(type = "number")]
pub root_chat_id: Option<i64>,
pub message: Option<String>,
// pub requested_content_format: Option<String>,
Expand Down
20 changes: 10 additions & 10 deletions src-tauri/src/database_source/crud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@ impl CRUDCreateUpdate for DatabaseSourceCreateUpdate {
}

fn get_column_names_values(&self) -> VecColumnNameValue {
let mut name_values: VecColumnNameValue = VecColumnNameValue::default();
let mut names_values: VecColumnNameValue = VecColumnNameValue::default();
if let Some(x) = &self.label {
name_values.push_name_value("label", InputValue::Text(x.clone()));
names_values.push_name_value("label", InputValue::Text(x.clone()));
}
if let Some(x) = &self.database_type {
name_values.push_name_value("database_type", InputValue::Text(x.clone()));
names_values.push_name_value("database_type", InputValue::Text(x.clone()));
}
if let Some(x) = &self.database_name {
name_values.push_name_value("database_name", InputValue::Text(x.clone()));
names_values.push_name_value("database_name", InputValue::Text(x.clone()));
}
if let Some(x) = &self.database_host {
name_values.push_name_value("database_host", InputValue::Text(x.clone()));
names_values.push_name_value("database_host", InputValue::Text(x.clone()));
}
if let Some(x) = &self.database_port {
name_values.push_name_value("database_port", InputValue::Text(x.to_string()));
names_values.push_name_value("database_port", InputValue::Text(x.to_string()));
}
if let Some(x) = &self.database_username {
name_values.push_name_value("database_username", InputValue::Text(x.clone()));
names_values.push_name_value("database_username", InputValue::Text(x.clone()));
}
if let Some(x) = &self.database_password {
name_values.push_name_value("database_password", InputValue::Text(x.clone()));
names_values.push_name_value("database_password", InputValue::Text(x.clone()));
}
name_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
name_values
names_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
names_values
}

async fn pre_insert(&self, _db_connection: &mut SqliteConnection) -> Result<(), DwataError> {
Expand Down
16 changes: 9 additions & 7 deletions src-tauri/src/directory_source/crud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@ impl CRUDCreateUpdate for DirectorySourceCreateUpdate {
}

fn get_column_names_values(&self) -> VecColumnNameValue {
let mut name_values: VecColumnNameValue = VecColumnNameValue::default();
let mut names_values: VecColumnNameValue = VecColumnNameValue::default();
if let Some(x) = &self.path {
name_values.push_name_value("path", InputValue::Text(x.clone()));
names_values.push_name_value("path", InputValue::Text(x.clone()));
}
if let Some(x) = &self.label {
name_values.push_name_value("label", InputValue::Text(x.clone()));
names_values.push_name_value("label", InputValue::Text(x.clone()));
}
if let Some(x) = &self.include_patterns {
name_values.push_name_value("include_patterns", InputValue::Json(serde_json::json!(x)));
names_values
.push_name_value("include_patterns", InputValue::Json(serde_json::json!(x)));
}
if let Some(x) = &self.exclude_patterns {
name_values.push_name_value("exclude_patterns", InputValue::Json(serde_json::json!(x)));
names_values
.push_name_value("exclude_patterns", InputValue::Json(serde_json::json!(x)));
}
name_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
name_values
names_values.push_name_value("created_at", InputValue::DateTime(Utc::now()));
names_values
}
}
9 changes: 6 additions & 3 deletions src-tauri/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ pub enum DwataError {
CouldNotConnectToAIProvider,
CouldNotGenerateEmbedding,
FeatureNotAvailableWithAIProvider,
ChatDoesNotHaveMessage,
ChatDoesNotHaveAIModel,
ChatHasBeenProcessedByAI,

// Chat and its processing related
ChatHasNoMessage,
NoRequestedAIModel,
AlreadyProcessedByAI,
ChatHasNoRootId,

// Integrated vector DB
CouldNotConnectToVectorDB,
Expand Down
10 changes: 10 additions & 0 deletions src-tauri/src/text_generation/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@ use crate::workspace::crud::InsertUpdateResponse;
use crate::workspace::DwataDb;
use tauri::State;

/// Tauri command to generate response for a chat thread.
///
/// # Arguments
///
/// * `root_chat_id` - The ID of the root chat of this thread to generate text for.
/// * `db` - The Dwata database connection.
///
/// # Returns
///
/// * `Result<InsertUpdateResponse, DwataError>` - The response from the AI model.
#[tauri::command]
pub async fn chat_with_ai(
chat_id: i64,
Expand Down
Loading

0 comments on commit ded3a5c

Please sign in to comment.