-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: initial TTS support * chore: lint, omitempty * chore: dont use pointer in struct * fix: add mocked server tests to speech_test.go Co-authored-by: Lachlan Laycock <[email protected]> * chore: update imports * chore: fix lint * chore: add an error check * chore: ignore lint * chore: add error checks in package * chore: add test * chore: fix test --------- Co-authored-by: Lachlan Laycock <[email protected]>
- Loading branch information
1 parent
b7cac70
commit 515de02
Showing
3 changed files
with
205 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package openai | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"io" | ||
"net/http" | ||
) | ||
|
||
// SpeechModel identifies a text-to-speech model. It is sent as the "model"
// field of a CreateSpeechRequest and validated by CreateSpeech.
type SpeechModel string

const (
	// TTSModel1 is the "tts-1" model identifier.
	TTSModel1 SpeechModel = "tts-1"
	// TTSModel1HD is the "tts-1-hd" model identifier.
	TTSModel1HD SpeechModel = "tts-1-hd"
	// TTsModel1HD is the original, inconsistently cased name for TTSModel1HD.
	// It is kept so existing callers continue to compile.
	//
	// Deprecated: use TTSModel1HD instead.
	TTsModel1HD SpeechModel = TTSModel1HD
)
|
||
// SpeechVoice names the voice used to render speech. It is sent as the
// "voice" field of a CreateSpeechRequest and validated by CreateSpeech.
type SpeechVoice string

// Voices accepted by CreateSpeech.
const (
	VoiceAlloy   SpeechVoice = "alloy"
	VoiceEcho    SpeechVoice = "echo"
	VoiceFable   SpeechVoice = "fable"
	VoiceOnyx    SpeechVoice = "onyx"
	VoiceNova    SpeechVoice = "nova"
	VoiceShimmer SpeechVoice = "shimmer"
)
|
||
// SpeechResponseFormat names the audio encoding requested through the
// "response_format" field of a CreateSpeechRequest.
type SpeechResponseFormat string

// Response formats that can be set on CreateSpeechRequest.ResponseFormat.
const (
	SpeechResponseFormatMp3  SpeechResponseFormat = "mp3"
	SpeechResponseFormatOpus SpeechResponseFormat = "opus"
	SpeechResponseFormatAac  SpeechResponseFormat = "aac"
	SpeechResponseFormatFlac SpeechResponseFormat = "flac"
)
|
||
// Sentinel errors returned by CreateSpeech when request validation fails.
var (
	// ErrInvalidSpeechModel reports a request whose Model is not a known SpeechModel.
	ErrInvalidSpeechModel = errors.New("invalid speech model")
	// ErrInvalidVoice reports a request whose Voice is not a known SpeechVoice.
	ErrInvalidVoice = errors.New("invalid voice")
)
|
||
type CreateSpeechRequest struct { | ||
Model SpeechModel `json:"model"` | ||
Input string `json:"input"` | ||
Voice SpeechVoice `json:"voice"` | ||
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 | ||
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 | ||
} | ||
|
||
// contains reports whether e is equal to at least one element of s.
// A nil or empty slice never contains anything.
func contains[T comparable](s []T, e T) bool {
	for i := 0; i < len(s); i++ {
		if s[i] != e {
			continue
		}
		return true
	}
	return false
}
|
||
func isValidSpeechModel(model SpeechModel) bool { | ||
return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) | ||
} | ||
|
||
func isValidVoice(voice SpeechVoice) bool { | ||
return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) | ||
} | ||
|
||
func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { | ||
if !isValidSpeechModel(request.Model) { | ||
err = ErrInvalidSpeechModel | ||
return | ||
} | ||
if !isValidVoice(request.Voice) { | ||
err = ErrInvalidVoice | ||
return | ||
} | ||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), | ||
withBody(request), | ||
withContentType("application/json; charset=utf-8"), | ||
) | ||
if err != nil { | ||
return | ||
} | ||
|
||
response, err = c.sendRequestRaw(req) | ||
|
||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package openai_test | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"mime" | ||
"net/http" | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
|
||
"github.com/sashabaranov/go-openai" | ||
"github.com/sashabaranov/go-openai/internal/test" | ||
"github.com/sashabaranov/go-openai/internal/test/checks" | ||
) | ||
|
||
func TestSpeechIntegration(t *testing.T) { | ||
client, server, teardown := setupOpenAITestServer() | ||
defer teardown() | ||
|
||
server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { | ||
dir, cleanup := test.CreateTestDirectory(t) | ||
path := filepath.Join(dir, "fake.mp3") | ||
test.CreateTestFile(t, path) | ||
defer cleanup() | ||
|
||
// audio endpoints only accept POST requests | ||
if r.Method != "POST" { | ||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | ||
return | ||
} | ||
|
||
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) | ||
if err != nil { | ||
http.Error(w, "failed to parse media type", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
if mediaType != "application/json" { | ||
http.Error(w, "request is not json", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
// Parse the JSON body of the request | ||
var params map[string]interface{} | ||
err = json.NewDecoder(r.Body).Decode(¶ms) | ||
if err != nil { | ||
http.Error(w, "failed to parse request body", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
// Check if each required field is present in the parsed JSON object | ||
reqParams := []string{"model", "input", "voice"} | ||
for _, param := range reqParams { | ||
_, ok := params[param] | ||
if !ok { | ||
http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) | ||
return | ||
} | ||
} | ||
|
||
// read audio file content | ||
audioFile, err := os.ReadFile(path) | ||
if err != nil { | ||
http.Error(w, "failed to read audio file", http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
// write audio file content to response | ||
w.Header().Set("Content-Type", "audio/mpeg") | ||
w.Header().Set("Transfer-Encoding", "chunked") | ||
w.Header().Set("Connection", "keep-alive") | ||
_, err = w.Write(audioFile) | ||
if err != nil { | ||
http.Error(w, "failed to write body", http.StatusInternalServerError) | ||
return | ||
} | ||
}) | ||
|
||
t.Run("happy path", func(t *testing.T) { | ||
res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ | ||
Model: openai.TTSModel1, | ||
Input: "Hello!", | ||
Voice: openai.VoiceAlloy, | ||
}) | ||
checks.NoError(t, err, "CreateSpeech error") | ||
defer res.Close() | ||
|
||
buf, err := io.ReadAll(res) | ||
checks.NoError(t, err, "ReadAll error") | ||
|
||
// save buf to file as mp3 | ||
err = os.WriteFile("test.mp3", buf, 0644) | ||
checks.NoError(t, err, "Create error") | ||
}) | ||
t.Run("invalid model", func(t *testing.T) { | ||
_, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ | ||
Model: "invalid_model", | ||
Input: "Hello!", | ||
Voice: openai.VoiceAlloy, | ||
}) | ||
checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") | ||
}) | ||
|
||
t.Run("invalid voice", func(t *testing.T) { | ||
_, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ | ||
Model: openai.TTSModel1, | ||
Input: "Hello!", | ||
Voice: "invalid_voice", | ||
}) | ||
checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") | ||
}) | ||
} |