123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- // Copyright 2014 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package main
- import (
- "fmt"
- "log"
- "net/http"
- "os"
- "path/filepath"
- "strings"
- "google.golang.org/api/googleapi"
- prediction "google.golang.org/api/prediction/v1.6"
- )
- func init() {
- scopes := []string{
- prediction.DevstorageFullControlScope,
- prediction.DevstorageReadOnlyScope,
- prediction.DevstorageReadWriteScope,
- prediction.PredictionScope,
- }
- registerDemo("prediction", strings.Join(scopes, " "), predictionMain)
- }
- type predictionType struct {
- api *prediction.Service
- projectNumber string
- bucketName string
- trainingFileName string
- modelName string
- }
- // This example demonstrates calling the Prediction API.
- // Training data is uploaded to a pre-created Google Cloud Storage Bucket and
- // then the Prediction API is called to train a model based on that data.
- // After a few minutes, the model should be completely trained and ready
- // for prediction. At that point, text is sent to the model and the Prediction
- // API attempts to classify the data, and the results are printed out.
- //
- // To get started, follow the instructions found in the "Hello Prediction!"
- // Getting Started Guide located here:
- // https://developers.google.com/prediction/docs/hello_world
- //
- // Example usage:
- // go-api-demo -clientid="my-clientid" -secret="my-secret" prediction
- // my-project-number my-bucket-name my-training-filename my-model-name
- //
- // Example output:
- // Predict result: language=Spanish
- // English Score: 0.000000
- // French Score: 0.000000
- // Spanish Score: 1.000000
- // analyze: output feature text=&{157 English}
- // analyze: output feature text=&{149 French}
- // analyze: output feature text=&{100 Spanish}
- // feature text count=406
- func predictionMain(client *http.Client, argv []string) {
- if len(argv) != 4 {
- fmt.Fprintln(os.Stderr,
- "Usage: prediction project_number bucket training_data model_name")
- return
- }
- api, err := prediction.New(client)
- if err != nil {
- log.Fatalf("unable to create prediction API client: %v", err)
- }
- t := &predictionType{
- api: api,
- projectNumber: argv[0],
- bucketName: argv[1],
- trainingFileName: argv[2],
- modelName: argv[3],
- }
- t.trainModel()
- t.predictModel()
- }
- func (t *predictionType) trainModel() {
- // First, check to see if our trained model already exists.
- res, err := t.api.Trainedmodels.Get(t.projectNumber, t.modelName).Do()
- if err != nil {
- if ae, ok := err.(*googleapi.Error); ok && ae.Code != http.StatusNotFound {
- log.Fatalf("error getting trained model: %v", err)
- }
- log.Printf("Training model not found, creating new model.")
- res, err = t.api.Trainedmodels.Insert(t.projectNumber, &prediction.Insert{
- Id: t.modelName,
- StorageDataLocation: filepath.Join(t.bucketName, t.trainingFileName),
- }).Do()
- if err != nil {
- log.Fatalf("unable to create trained model: %v", err)
- }
- }
- if res.TrainingStatus != "DONE" {
- // Wait for the trained model to finish training.
- fmt.Printf("Training model. Please wait and re-run program after a few minutes.")
- os.Exit(0)
- }
- }
- func (t *predictionType) predictModel() {
- // Model has now been trained. Predict with it.
- input := &prediction.Input{
- Input: &prediction.InputInput{
- CsvInstance: []interface{}{
- "Hola, con quien hablo",
- },
- },
- }
- res, err := t.api.Trainedmodels.Predict(t.projectNumber, t.modelName, input).Do()
- if err != nil {
- log.Fatalf("unable to get trained prediction: %v", err)
- }
- fmt.Printf("Predict result: language=%v\n", res.OutputLabel)
- for _, m := range res.OutputMulti {
- fmt.Printf("%v Score: %v\n", m.Label, m.Score)
- }
- // Now analyze the model.
- an, err := t.api.Trainedmodels.Analyze(t.projectNumber, t.modelName).Do()
- if err != nil {
- log.Fatalf("unable to analyze trained model: %v", err)
- }
- for _, f := range an.DataDescription.OutputFeature.Text {
- fmt.Printf("analyze: output feature text=%v\n", f)
- }
- for _, f := range an.DataDescription.Features {
- fmt.Printf("feature text count=%v\n", f.Text.Count)
- }
- }
|