From 8755e0128e91b8c7882b9eeb79ac0933be4aeb4b Mon Sep 17 00:00:00 2001 From: "liron.levin" Date: Tue, 14 Nov 2023 12:05:08 +0200 Subject: [PATCH] Add implemenation for get fine-tunning jobs Missing implementation for GET /fine_tuning/jobs --- fine_tuning_job.go | 42 ++++++++++++++++++++++++-- fine_tuning_job_test.go | 67 +++++++++++++++++++++++++++++------------ 2 files changed, 87 insertions(+), 22 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 9dcb49de1..11c571a3d 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -7,6 +7,14 @@ import ( "net/url" ) +// FineTuningJobList is a list of fine-tune jobs. +type FineTuningJobList struct { + FineTuningJobs []FineTuningJob `json:"data"` + HasMore bool `json:"has_more"` + + httpHeader +} + type FineTuningJob struct { ID string `json:"id"` Object string `json:"object"` @@ -55,13 +63,43 @@ type FineTuningJobEvent struct { Type string `json:"type"` } +const fineTuningJobSuffix = "/fine_tuning/jobs" + // CreateFineTuningJob create a fine tuning job. func (c *Client) CreateFineTuningJob( ctx context.Context, request FineTuningJobRequest, ) (response FineTuningJob, err error) { - urlSuffix := "/fine_tuning/jobs" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(fineTuningJobSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListFineTuningJobs Lists the fine-tuning jobs +func (c *Client) ListFineTuningJobs( + ctx context.Context, + limit *int, + after *string, +) (response FineTuningJobList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", fineTuningJobSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index d2fbcd4c7..c293598e4 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -13,31 +13,48 @@ import ( const testFineTuninigJobID = "fine-tuning-job-id" -// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. +// TestFineTuningJob Tests the fine-tuning job endpoint of the API using the mocked server. func TestFineTuningJob(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler( "/v1/fine_tuning/jobs", - func(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(openai.FineTuningJob{ - Object: "fine_tuning.job", - ID: testFineTuninigJobID, - Model: "davinci-002", - CreatedAt: 1692661014, - FinishedAt: 1692661190, - FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", - OrganizationID: "org-123", - ResultFiles: []string{"file-abc123"}, - Status: "succeeded", - ValidationFile: "", - TrainingFile: "file-abc123", - Hyperparameters: openai.Hyperparameters{ - Epochs: "auto", - }, - TrainedTokens: 5768, - }) - fmt.Fprintln(w, string(resBytes)) + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.FineTuningJob{ + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + ResultFiles: []string{"file-abc123"}, + Status: "succeeded", + ValidationFile: "", + TrainingFile: "file-abc123", + Hyperparameters: openai.Hyperparameters{ + Epochs: "auto", + }, + TrainedTokens: 5768, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.FineTuningJobList{ + FineTuningJobs: []openai.FineTuningJob{ + { + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } }, ) @@ -94,6 +111,16 @@ func TestFineTuningJob(t *testing.T) { ) checks.NoError(t, err, "ListFineTuningJobEvents error") + jobs, err := client.ListFineTuningJobs( + ctx, + nil, + nil, + ) + checks.NoError(t, err, "ListFineTuningJobs error") + if len(jobs.FineTuningJobs) != 1 { + t.Errorf("no fine tuning jobs found") + } + _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID,