Skip to content

Commit

Permalink
refactor: add a base class LLMStreamService (#561)
Browse files Browse the repository at this point in the history
* refactor: add a base class LLMStreamService

* refactor: change GeminiService to inherit from LLMStreamService

* perf: remove #available macos-12 in Gemini

* chore: update swift lint

* perf: remove unused swiftlint:disable

* perf: improve structure between LLMStreamService and BaseOpenAIService

* perf: make subclass must override properties availableModels, apiKey and endpoint

* perf: mark model as must be overridden
  • Loading branch information
tisfeng committed May 25, 2024
1 parent a55487c commit 4d9eb5a
Show file tree
Hide file tree
Showing 21 changed files with 251 additions and 286 deletions.
5 changes: 4 additions & 1 deletion .swiftlint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ disabled_rules:
- force_try
- large_tuple
- todo
- no_fallthrough_only

opt_in_rules:
- convenience_type
Expand All @@ -33,6 +34,7 @@ opt_in_rules:
line_length:
warning: 120
ignores_comments: true
ignores_interpolated_strings: true
function_body_length:
warning: 120
error: 400
Expand All @@ -48,7 +50,7 @@ type_name:
warning: 50
error: 50
identifier_name:
min_length: 3
min_length: 2
excluded: # excluded via string array
- id
- URL
Expand All @@ -58,6 +60,7 @@ identifier_name:
- i
- j
- Defaults # Make use of `SwiftyUserDefaults`
- to
reporter: "xcode" # reporter type (xcode, json, csv, checkstyle, junit, html, emoji)
trailing_comma:
severity: warning
Expand Down
4 changes: 4 additions & 0 deletions Easydict.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
0383914F292FBE120009828C /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 03839145292FBE120009828C /* Main.storyboard */; };
03839150292FBE120009828C /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 03839147292FBE120009828C /* main.m */; };
03839151292FBE120009828C /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 03839148292FBE120009828C /* AppDelegate.m */; };
0387FB7A2BFBA990000A7A82 /* LLMStreamService.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0387FB792BFBA990000A7A82 /* LLMStreamService.swift */; };
03882F8D29D95044005B5A52 /* CTView.m in Sources */ = {isa = PBXBuildFile; fileRef = 03882F8429D95044005B5A52 /* CTView.m */; };
03882F8E29D95044005B5A52 /* ToastWindowController.m in Sources */ = {isa = PBXBuildFile; fileRef = 03882F8629D95044005B5A52 /* ToastWindowController.m */; };
03882F8F29D95044005B5A52 /* CTScreen.m in Sources */ = {isa = PBXBuildFile; fileRef = 03882F8729D95044005B5A52 /* CTScreen.m */; };
Expand Down Expand Up @@ -492,6 +493,7 @@
03839149292FBE120009828C /* EasydictHelper.entitlements */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.entitlements; path = EasydictHelper.entitlements; sourceTree = "<group>"; };
0383914A292FBE120009828C /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
0383914B292FBE120009828C /* ViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ViewController.h; sourceTree = "<group>"; };
0387FB792BFBA990000A7A82 /* LLMStreamService.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LLMStreamService.swift; sourceTree = "<group>"; };
03882F8229D95044005B5A52 /* CTScreen.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CTScreen.h; sourceTree = "<group>"; };
03882F8329D95044005B5A52 /* ToastWindowController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ToastWindowController.h; sourceTree = "<group>"; };
03882F8429D95044005B5A52 /* CTView.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CTView.m; sourceTree = "<group>"; };
Expand Down Expand Up @@ -1270,6 +1272,7 @@
03779F0D2BB256A7008D3C42 /* OpenAI */ = {
isa = PBXGroup;
children = (
0387FB792BFBA990000A7A82 /* LLMStreamService.swift */,
0396DE542BB5844A009FD2A5 /* BaseOpenAIService.swift */,
03779F0B2BB256A7008D3C42 /* OpenAIService.swift */,
03779F0C2BB256A7008D3C42 /* Prompt.swift */,
Expand Down Expand Up @@ -3192,6 +3195,7 @@
0309E1ED292B439A00AFB76A /* EZTextView.m in Sources */,
03B0232B29231FA6001C7E63 /* NSMutableAttributedString+MM.m in Sources */,
03B022E829231FA6001C7E63 /* entry.m in Sources */,
0387FB7A2BFBA990000A7A82 /* LLMStreamService.swift in Sources */,
039F5504294B6E29004AB940 /* EZPreferencesWindowController.m in Sources */,
03008B3F29444B0A0062B821 /* NSView+EZAnimatedHidden.m in Sources */,
03B022FD29231FA6001C7E63 /* EZFixedQueryWindow.m in Sources */,
Expand Down
5 changes: 2 additions & 3 deletions Easydict/Swift/Service/Ali/AliResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ enum AnyCodable: Codable {
switch self {
case let .int(i):
String(i)
// swiftlint:disable:next identifier_name
case let .string(s):
s
case let .string(str):
str
}
}

Expand Down
5 changes: 1 addition & 4 deletions Easydict/Swift/Service/Ali/AliService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
// Copyright © 2023 izual. All rights reserved.
//

// swiftlint:disable all

import Alamofire
import CryptoKit
import Defaults
Expand Down Expand Up @@ -140,6 +138,7 @@ class AliService: QueryService {
}
}

// swiftlint:disable:next function_parameter_count
private func requestByAPI(
id: String,
secret: String,
Expand Down Expand Up @@ -346,5 +345,3 @@ class AliService: QueryService {
}, serviceType: serviceType().rawValue)
}
}

// swiftlint:enable all
4 changes: 0 additions & 4 deletions Easydict/Swift/Service/Ali/AliTranslateType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
// Copyright © 2023 izual. All rights reserved.
//

// swiftlint:disable all

import Foundation

struct AliTranslateType: Equatable {
Expand Down Expand Up @@ -103,5 +101,3 @@ struct AliTranslateType: Equatable {
return AliTranslateType(sourceLanguage: fromLanguage, targetLanguage: toLanguage)
}
}

// swiftlint:enable all
1 change: 0 additions & 1 deletion Easydict/Swift/Service/BuiltInAI/BuiltInAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class BuiltInAIService: BaseOpenAIService {
}
return model
}

set {
Defaults[.builtInAIModel] = newValue
}
Expand Down
3 changes: 0 additions & 3 deletions Easydict/Swift/Service/Caiyun/CaiyunResponse.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// swiftlint:disable all
//
// CaiyunResponse.swift
// Easydict
Expand All @@ -14,5 +13,3 @@ struct CaiyunResponse: Codable {
var rc: Int
var target: [String]
}

// swiftlint:enable all
4 changes: 0 additions & 4 deletions Easydict/Swift/Service/Caiyun/CaiyunService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
// Copyright © 2023 izual. All rights reserved.
//

// swiftlint:disable all

import Alamofire
import Defaults
import Foundation
Expand Down Expand Up @@ -130,5 +128,3 @@ public final class CaiyunService: QueryService {
enum QueryServiceError: Error {
case notSupported
}

// swiftlint:enable all
4 changes: 0 additions & 4 deletions Easydict/Swift/Service/Caiyun/CaiyunTranslateType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
// Copyright © 2023 izual. All rights reserved.
//

// swiftlint:disable all

import Foundation

struct CaiyunTranslateType: RawRepresentable {
Expand Down Expand Up @@ -60,5 +58,3 @@ struct CaiyunTranslateType: RawRepresentable {
return CaiyunTranslateType(rawValue: "\(from)2\(to)")
}
}

// swiftlint:enable all
3 changes: 0 additions & 3 deletions Easydict/Swift/Service/CustomOpenAI/CustomOpenAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
// Copyright © 2024 izual. All rights reserved.
//

import Alamofire
import CryptoKit
import Defaults
import Foundation

Expand Down Expand Up @@ -41,7 +39,6 @@ class CustomOpenAIService: BaseOpenAIService {
get {
Defaults[.customOpenAIModel]
}

set {
Defaults[.customOpenAIModel] = newValue
}
Expand Down
126 changes: 50 additions & 76 deletions Easydict/Swift/Service/Gemini/GeminiService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
// Copyright © 2024 izual. All rights reserved.
//

// swiftlint:disable all

import Defaults
import Foundation
import GoogleGenerativeAI

// TODO: add a LLM stream service base class, make both OpenAI and Gemini inherit from it.
@objc(EZGeminiService)
public final class GeminiService: QueryService {
public final class GeminiService: LLMStreamService {
// MARK: Public

override public func serviceType() -> ServiceType {
Expand All @@ -29,21 +27,8 @@ public final class GeminiService: QueryService {
NSLocalizedString("gemini_translate", comment: "The name of Gemini Translate")
}

override public func supportLanguagesDictionary() -> MMOrderedDictionary<AnyObject, AnyObject> {
// TODO: Replace MMOrderedDictionary.
let orderedDict = MMOrderedDictionary<AnyObject, AnyObject>()
for language in EZLanguageManager.shared().allLanguages {
let value = language.rawValue
if !GeminiService.unsupportedLanguages.contains(language) {
orderedDict.setObject(value as NSString, forKey: language.rawValue as NSString)
}
}

return orderedDict
}

public override func isStream() -> Bool {
true
override public func queryTextType() -> EZQueryTextType {
[.translation]
}

override public func translate(
Expand All @@ -55,54 +40,45 @@ public final class GeminiService: QueryService {
Task {
do {
let translationPrompt = translationPrompt(text: text, from: from, to: to)
let prompt = QueryService.translationSystemPrompt +
let prompt = LLMStreamService.translationSystemPrompt +
"\n" + translationPrompt
// logInfo("gemini prompt: \(prompt)")
let model = GenerativeModel(
name: "gemini-pro",
apiKey: apiKey,
safetySettings: [
GeminiService.harassmentSafety,
GeminiService.hateSpeechSafety,
GeminiService.sexuallyExplicitSafety,
GeminiService.dangerousContentSafety,
harassmentBlockNone,
hateSpeechBlockNone,
sexuallyExplicitBlockNone,
dangerousContentBlockNone,
]
)

result.isStreamFinished = false

var resultString = ""

// Gemini Docs: https://github.com/google/generative-ai-swift
if #available(macOS 12.0, *) {
result.isStreamFinished = false

var resultString = ""
let outputContentStream = model.generateContentStream(prompt)
let outputContentStream = model.generateContentStream(prompt)
for try await outputContent in outputContentStream {
guard let line = outputContent.text else {
return
}
if !result.isStreamFinished {
resultString += line

for try await outputContent in outputContentStream {
guard let line = outputContent.text else {
return
}
if !result.isStreamFinished {
resultString += line
result.translatedResults = [resultString]
await MainActor.run {
throttler.throttle { [unowned self] in
completion(result, nil)
}
result.translatedResults = [resultString]
await MainActor.run {
throttler.throttle { [unowned self] in
completion(result, nil)
}
}
}
result.isStreamFinished = true
completion(result, nil)
} else {
// Gemini does not support stream in macOS 12.0-
let outputContent = try await model.generateContent(prompt)
guard let resultString = outputContent.text else {
return
}
result.translatedResults = [resultString]
await MainActor.run {
completion(result, nil)
}
}

result.isStreamFinished = true
result.translatedResults = [getFinalResultText(text: resultString)]
completion(result, nil)
} catch {
/**
https://github.com/google/generative-ai-swift/issues/89
Expand All @@ -126,34 +102,32 @@ public final class GeminiService: QueryService {

// MARK: Internal

let throttler = Throttler()

// MARK: Private

// https://ai.google.dev/available_regions
private static let unsupportedLanguages: [Language] = [
.persian,
.filipino,
.khmer,
.lao,
.malay,
.mongolian,
.burmese,
.telugu,
.tamil,
.urdu,
]

// Set Gemini safety level to BLOCK_NONE
private static let harassmentSafety = SafetySetting(harmCategory: .harassment, threshold: .blockNone)
private static let hateSpeechSafety = SafetySetting(harmCategory: .hateSpeech, threshold: .blockNone)
private static let sexuallyExplicitSafety = SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockNone)
private static let dangerousContentSafety = SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone)
override var unsupportedLanguages: [Language] {
[
.persian,
.filipino,
.khmer,
.lao,
.malay,
.mongolian,
.burmese,
.telugu,
.tamil,
.urdu,
]
}

// easydict://writeKeyValue?EZGeminiAPIKey=xxx
private var apiKey: String {
override var apiKey: String {
Defaults[.geminiAPIKey] ?? ""
}
}

// swiftlint:enable all
// MARK: Private

// Set Gemini safety level to BLOCK_NONE
private let harassmentBlockNone = SafetySetting(harmCategory: .harassment, threshold: .blockNone)
private let hateSpeechBlockNone = SafetySetting(harmCategory: .hateSpeech, threshold: .blockNone)
private let sexuallyExplicitBlockNone = SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockNone)
private let dangerousContentBlockNone = SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone)
}
Loading

0 comments on commit 4d9eb5a

Please sign in to comment.